Commit 35da928

refactor(NxDPreTrainedModel): regroup NeuronPreTrainedModel methods
1 parent 7a192fb commit 35da928
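
This commit regroups the NeuronPreTrainedModel loading entry points: _from_pretrained and _export are removed from the causal-LM decoder class in modeling_decoder.py and re-added, nearly verbatim, on the shared base class in pretrained_model.py (the regrouped _from_pretrained now calls cls.create_graph_builders instead of NxDModelForCausalLM.create_graph_builders). In addition, NxDGenerationMixin becomes an ABC with an abstract reset method, replacing the no-op reset previously defined on the base model.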

3 files changed: +145 additions, -150 deletions

optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py

Lines changed: 1 addition & 141 deletions
@@ -14,26 +14,18 @@
 # limitations under the License.
 import copy
 import logging
-import os
-from pathlib import Path
-from tempfile import TemporaryDirectory

 import neuronx_distributed as nxd
 import torch
-from huggingface_hub import snapshot_download
 from neuronx_distributed.operators.argmax import argmax as nxd_argmax
 from neuronx_distributed.parallel_layers.layers import SPMDRank
 from neuronx_distributed.parallel_layers.mappings import (
     _gather_along_dim,
 )
 from torch import nn
-from transformers import AutoConfig, PretrainedConfig
+from transformers import PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast

-from ......cache.entries.single_model import SingleModelCacheEntry
-from ......cache.hub_cache import hub_neuronx_cache
-from ......utils.instance import align_compilation_target, current_instance_type
-from ......utils.system import get_available_cores
 from ....modeling_utils import NeuronModelForCausalLM
 from ...config import NxDNeuronConfig
 from ...graph_builder import NxDGraphBuilder
@@ -574,135 +566,3 @@ def get_compiler_args(cls, neuron_config: NxDNeuronConfig) -> str:

         logging.info(f"neuronx-cc compiler_args are: {compiler_args}")
         return compiler_args
-
-    # NeuronPreTrainedModel methods
-    @classmethod
-    def _from_pretrained(
-        cls,
-        model_id: "str | Path",
-        config: "PretrainedConfig",
-        revision: str | None = None,
-        token: bool | str | None = None,
-        cache_dir: str | None = None,
-        force_download: bool | None = False,
-        local_files_only: bool | None = False,
-        **kwargs,
-    ) -> "NeuronModelForCausalLM":
-        if len(kwargs) > 0:
-            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
-        neuron_config = NxDNeuronConfig.from_pretrained(model_id)
-        # Check the current instance type is compatible with the one used to compile the model
-        if neuron_config.target != current_instance_type():
-            raise ValueError(
-                f"The model was compiled for {neuron_config.target} but the current instance type is "
-                f"{current_instance_type()}. Please use a compatible instance type."
-            )
-        # Also check the number of cores is at least equal to the tensor parallel size
-        if get_available_cores() < neuron_config.tp_degree:
-            raise ValueError(
-                f"The model requires at least {neuron_config.tp_degree} Neuron cores but only "
-                f"{get_available_cores()} are available. Please use a compatible instance type."
-            )
-        if not os.path.exists(model_id):
-            # The model_id is a model hub id: download the model from the hub.
-            with TemporaryDirectory() as tmpdir:
-                snapshot_download(
-                    repo_id=model_id,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    local_dir=tmpdir,
-                    force_download=force_download,
-                    local_files_only=local_files_only,
-                    token=token,
-                    allow_patterns=[cls.COMPILED_MODEL_FILE_NAME],
-                )
-                traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
-        else:
-            traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
-        graph_builders = NxDModelForCausalLM.create_graph_builders(config=config, neuron_config=neuron_config)
-        model = cls(
-            config=config,
-            neuron_config=neuron_config,
-            traced_model=traced_model,
-            graph_builders=graph_builders,
-        )
-        model.load_weights(
-            model_id,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            local_files_only=local_files_only,
-            token=token,
-        )
-        return model
-
-    @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig | None",
-        neuron_config: "NxDNeuronConfig",
-        token: bool | str | None = None,
-        revision: str | None = None,
-        cache_dir: str | None = None,
-        force_download: bool | None = False,
-        local_files_only: bool | None = False,
-        trust_remote_code: bool | None = False,
-        load_weights: bool | None = False,
-        **kwargs,
-    ) -> "NeuronModelForCausalLM":
-        if len(kwargs) > 0:
-            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
-        # Try to align compilation target. We do not allow override as neuronx-distributed is already initialized.
-        compilation_target = align_compilation_target(neuron_config.target, override=False)
-        if compilation_target != neuron_config.target:
-            raise ValueError(
-                f"The compilation target is {neuron_config.target} but the NEURON_PLATFORM_TARGET_OVERRIDE"
-                f" environment variable is set to {compilation_target}, Please set it to the correct value."
-            )
-        if config is None:
-            # Get the text config if not provided
-            config = AutoConfig.from_pretrained(
-                model_id,
-                token=token,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                trust_remote_code=trust_remote_code,
-            ).get_text_config()
-        # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
-        config.torch_dtype = neuron_config.torch_dtype
-        # Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
-        if hasattr(config, "head_dim") and config.head_dim is None:
-            config.head_dim = config.hidden_size // config.num_attention_heads
-        graph_builders = cls.create_graph_builders(
-            config=config,
-            neuron_config=neuron_config,
-        )
-        # The model NEFF files will be cached locally, but if the model_id corresponds
-        # to a hub model, we also create a cache entry for it.
-        cache_entry = (
-            None
-            if os.path.exists(model_id)
-            else SingleModelCacheEntry(model_id, task="text-generation", config=config, neuron_config=neuron_config)
-        )
-        with hub_neuronx_cache(entry=cache_entry):
-            traced_model = NxDPreTrainedModel.compile(
-                neuron_config=neuron_config,
-                graph_builders=graph_builders,
-                compiler_args=cls.get_compiler_args(neuron_config),
-            )
-        model = cls(
-            config=config,
-            neuron_config=neuron_config,
-            traced_model=traced_model,
-            graph_builders=graph_builders,
-        )
-        if load_weights:
-            model.load_weights(
-                model_id,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                local_files_only=local_files_only,
-                token=token,
-            )
-        return model
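
The net effect for this file: _from_pretrained and _export disappear from the decoder class and are inherited from the base class instead. A minimal sketch of the resulting layout, assuming the subclass relation implied by the imports (bodies elided, signatures abbreviated):

class NxDPreTrainedModel:  # backend/pretrained_model.py
    @classmethod
    def _export(cls, model_id, config, neuron_config, **kwargs): ...

    @classmethod
    def _from_pretrained(cls, model_id, config, **kwargs): ...


class NxDModelForCausalLM(NxDPreTrainedModel):  # modules/decoder/modeling_decoder.py
    # Only decoder-specific logic remains here (e.g. get_compiler_args,
    # create_graph_builders); the loading entry points are inherited.
    @classmethod
    def get_compiler_args(cls, neuron_config): ...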

optimum/neuron/models/inference/backend/modules/generation/generation_utils.py

Lines changed: 6 additions & 1 deletion
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+from abc import ABC, abstractmethod
 from typing import Any

 import torch
@@ -28,7 +29,7 @@
 )


-class NxDGenerationMixin(GenerationMixin):
+class NxDGenerationMixin(GenerationMixin, ABC):
     """A generation Mixin that can be used to extend NxDPreTrainedModel based classes"""

     # These are expected to be set by the GenerationMixin code
@@ -425,3 +426,7 @@ def device(self) -> torch.device:
         """
         # We dont want HF to move parameters to device
        return torch.device("cpu")
+
+    @abstractmethod
+    def reset(self):
+        raise SystemError(f"The reset method must be implemented by {self.__class__.__name__}")
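
A minimal sketch (MyNeuronModel and its state are hypothetical) of the contract this hunk introduces: because NxDGenerationMixin is now an ABC with an abstract reset(), concrete models must implement it or they cannot be instantiated.

from abc import ABC, abstractmethod


class NxDGenerationMixin(ABC):  # stand-in for the real mixin
    @abstractmethod
    def reset(self):
        raise SystemError(f"The reset method must be implemented by {self.__class__.__name__}")


class MyNeuronModel(NxDGenerationMixin):
    def __init__(self):
        self.kv_cache_populated = False

    def reset(self):
        # Clear per-request decoding state between generate() calls.
        self.kv_cache_populated = False


model = MyNeuronModel()  # OK: reset() is implemented
model.reset()
# A subclass that omits reset() would raise TypeError at instantiation time.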

optimum/neuron/models/inference/backend/pretrained_model.py

Lines changed: 138 additions & 8 deletions
@@ -18,20 +18,23 @@
 from abc import ABC, abstractmethod
 from functools import partial
 from pathlib import Path
+from tempfile import TemporaryDirectory

 import neuronx_distributed.trace.hlo_utils as hlo_utils
 import torch
 from huggingface_hub import HfApi, snapshot_download
 from neuronx_distributed.trace.model_builder import ModelBuilder
 from safetensors.torch import load_file
-from transformers import AutoModelForCausalLM, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig

+from ....cache.entries.single_model import SingleModelCacheEntry
+from ....cache.hub_cache import hub_neuronx_cache
+from ....utils.instance import align_compilation_target, current_instance_type
+from ....utils.system import get_available_cores
 from ..modeling_utils import NeuronPreTrainedModel
 from .config import NxDNeuronConfig
 from .graph_builder import NxDGraphBuilder
-from .modules.checkpoint import (
-    load_state_dict,
-)
+from .modules.checkpoint import load_state_dict


 logger = logging.getLogger("Neuron")
@@ -305,11 +308,138 @@ def device(self) -> torch.device:
         # We dont want HF to move parameters to device
         return torch.device("cpu")

-    def reset(self):
-        """Resets the model state. Can be implemented by subclasses."""
-        pass
-
     # NeuronPreTrainedModel methods
+    @classmethod
+    def _export(
+        cls,
+        model_id: str,
+        config: "PretrainedConfig | None",
+        neuron_config: "NxDNeuronConfig",
+        token: bool | str | None = None,
+        revision: str | None = None,
+        cache_dir: str | None = None,
+        force_download: bool | None = False,
+        local_files_only: bool | None = False,
+        trust_remote_code: bool | None = False,
+        load_weights: bool | None = False,
+        **kwargs,
+    ) -> NeuronPreTrainedModel:
+        if len(kwargs) > 0:
+            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
+        # Try to align compilation target. We do not allow override as neuronx-distributed is already initialized.
+        compilation_target = align_compilation_target(neuron_config.target, override=False)
+        if compilation_target != neuron_config.target:
+            raise ValueError(
+                f"The compilation target is {neuron_config.target} but the NEURON_PLATFORM_TARGET_OVERRIDE"
+                f" environment variable is set to {compilation_target}, Please set it to the correct value."
+            )
+        if config is None:
+            # Get the text config if not provided
+            config = AutoConfig.from_pretrained(
+                model_id,
+                token=token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+            ).get_text_config()
+        # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
+        config.torch_dtype = neuron_config.torch_dtype
+        # Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
+        if hasattr(config, "head_dim") and config.head_dim is None:
+            config.head_dim = config.hidden_size // config.num_attention_heads
+        graph_builders = cls.create_graph_builders(
+            config=config,
+            neuron_config=neuron_config,
+        )
+        # The model NEFF files will be cached locally, but if the model_id corresponds
+        # to a hub model, we also create a cache entry for it.
+        cache_entry = (
+            None
+            if os.path.exists(model_id)
+            else SingleModelCacheEntry(model_id, task="text-generation", config=config, neuron_config=neuron_config)
+        )
+        with hub_neuronx_cache(entry=cache_entry):
+            traced_model = NxDPreTrainedModel.compile(
+                neuron_config=neuron_config,
+                graph_builders=graph_builders,
+                compiler_args=cls.get_compiler_args(neuron_config),
+            )
+        model = cls(
+            config=config,
+            neuron_config=neuron_config,
+            traced_model=traced_model,
+            graph_builders=graph_builders,
+        )
+        if load_weights:
+            model.load_weights(
+                model_id,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+                token=token,
+            )
+        return model
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: "str | Path",
+        config: "PretrainedConfig",
+        revision: str | None = None,
+        token: bool | str | None = None,
+        cache_dir: str | None = None,
+        force_download: bool | None = False,
+        local_files_only: bool | None = False,
+        **kwargs,
+    ) -> NeuronPreTrainedModel:
+        if len(kwargs) > 0:
+            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
+        neuron_config = NxDNeuronConfig.from_pretrained(model_id)
+        # Check the current instance type is compatible with the one used to compile the model
+        if neuron_config.target != current_instance_type():
+            raise ValueError(
+                f"The model was compiled for {neuron_config.target} but the current instance type is "
+                f"{current_instance_type()}. Please use a compatible instance type."
+            )
+        # Also check the number of cores is at least equal to the tensor parallel size
+        if get_available_cores() < neuron_config.tp_degree:
+            raise ValueError(
+                f"The model requires at least {neuron_config.tp_degree} Neuron cores but only "
+                f"{get_available_cores()} are available. Please use a compatible instance type."
+            )
+        if not os.path.exists(model_id):
+            # The model_id is a model hub id: download the model from the hub.
+            with TemporaryDirectory() as tmpdir:
+                snapshot_download(
+                    repo_id=model_id,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    local_dir=tmpdir,
+                    force_download=force_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    allow_patterns=[cls.COMPILED_MODEL_FILE_NAME],
+                )
+                traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
+        else:
+            traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
+        graph_builders = cls.create_graph_builders(config=config, neuron_config=neuron_config)
+        model = cls(
+            config=config,
+            neuron_config=neuron_config,
+            traced_model=traced_model,
+            graph_builders=graph_builders,
+        )
+        model.load_weights(
+            model_id,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+        )
+        return model
+
     def _save_pretrained(self, save_directory: str | Path, **kwargs):
         model_name_or_path = getattr(self.config, "_name_or_path")
         # If the model was exported from a local path, we need to save the checkpoint (not that we also shard it)
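
For context, a hedged usage sketch: these private classmethods are not called directly; in optimum-neuron they are typically reached through the public from_pretrained entry point of a model class such as NeuronModelForCausalLM. The model id and keyword arguments below are illustrative assumptions, not confirmed by this diff.

from optimum.neuron import NeuronModelForCausalLM

# Exporting (compiling) from a hub checkpoint is expected to dispatch to _export:
model = NeuronModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # hypothetical model id
    export=True,
    batch_size=1,
    sequence_length=2048,
)
model.save_pretrained("./qwen-neuron")

# Reloading the precompiled artifact is expected to dispatch to _from_pretrained:
model = NeuronModelForCausalLM.from_pretrained("./qwen-neuron")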
