|
18 | 18 | from abc import ABC, abstractmethod |
19 | 19 | from functools import partial |
20 | 20 | from pathlib import Path |
| 21 | +from tempfile import TemporaryDirectory |
21 | 22 |
|
22 | 23 | import neuronx_distributed.trace.hlo_utils as hlo_utils |
23 | 24 | import torch |
24 | 25 | from huggingface_hub import HfApi, snapshot_download |
25 | 26 | from neuronx_distributed.trace.model_builder import ModelBuilder |
26 | 27 | from safetensors.torch import load_file |
27 | | -from transformers import AutoModelForCausalLM, PretrainedConfig |
| 28 | +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig |
28 | 29 |
|
| 30 | +from ....cache.entries.single_model import SingleModelCacheEntry |
| 31 | +from ....cache.hub_cache import hub_neuronx_cache |
| 32 | +from ....utils.instance import align_compilation_target, current_instance_type |
| 33 | +from ....utils.system import get_available_cores |
29 | 34 | from ..modeling_utils import NeuronPreTrainedModel |
30 | 35 | from .config import NxDNeuronConfig |
31 | 36 | from .graph_builder import NxDGraphBuilder |
32 | | -from .modules.checkpoint import ( |
33 | | - load_state_dict, |
34 | | -) |
| 37 | +from .modules.checkpoint import load_state_dict |
35 | 38 |
|
36 | 39 |
|
37 | 40 | logger = logging.getLogger("Neuron") |
@@ -305,11 +308,138 @@ def device(self) -> torch.device: |
305 | 308 | # We don't want HF to move parameters to device |
306 | 309 | return torch.device("cpu") |
307 | 310 |
|
308 | | - def reset(self): |
309 | | - """Resets the model state. Can be implemented by subclasses.""" |
310 | | - pass |
311 | | - |
312 | 311 | # NeuronPreTrainedModel methods |
@classmethod
def _export(
    cls,
    model_id: str,
    config: "PretrainedConfig | None",
    neuron_config: "NxDNeuronConfig",
    token: bool | str | None = None,
    revision: str | None = None,
    cache_dir: str | None = None,
    force_download: bool | None = False,
    local_files_only: bool | None = False,
    trust_remote_code: bool | None = False,
    load_weights: bool | None = False,
    **kwargs,
) -> NeuronPreTrainedModel:
    """Compile the model identified by `model_id` and wrap the result in `cls`.

    Args:
        model_id: Local path or hub identifier. When the path does not exist
            locally, a `SingleModelCacheEntry` is created for the compilation.
        config: Model configuration. When `None`, it is fetched with
            `AutoConfig.from_pretrained(model_id, ...).get_text_config()`.
        neuron_config: Neuron configuration; its `target` must match the aligned
            compilation target and its `torch_dtype` is written into `config`.
        token, revision, cache_dir, force_download, local_files_only,
        trust_remote_code: Forwarded to `AutoConfig.from_pretrained` when the
            configuration has to be fetched.
        load_weights: When truthy, `model.load_weights` is called on the new
            model before it is returned.
        **kwargs: Unsupported; only logged in a warning.

    Returns:
        An instance of `cls` built from the compiled (traced) model.

    Raises:
        ValueError: If the aligned compilation target differs from
            `neuron_config.target`.
    """
    if kwargs:
        logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
    # Try to align compilation target. We do not allow override as neuronx-distributed is already initialized.
    compilation_target = align_compilation_target(neuron_config.target, override=False)
    if compilation_target != neuron_config.target:
        raise ValueError(
            f"The compilation target is {neuron_config.target} but the NEURON_PLATFORM_TARGET_OVERRIDE"
            f" environment variable is set to {compilation_target}, Please set it to the correct value."
        )
    if config is None:
        # Get the text config if not provided. `local_files_only` is forwarded so
        # that an offline export does not try to reach the hub for the config.
        config = AutoConfig.from_pretrained(
            model_id,
            token=token,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            trust_remote_code=trust_remote_code,
        ).get_text_config()
    # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
    config.torch_dtype = neuron_config.torch_dtype
    # Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
    if hasattr(config, "head_dim") and config.head_dim is None:
        config.head_dim = config.hidden_size // config.num_attention_heads
    graph_builders = cls.create_graph_builders(
        config=config,
        neuron_config=neuron_config,
    )
    # The model NEFF files will be cached locally, but if the model_id corresponds
    # to a hub model, we also create a cache entry for it.
    cache_entry = (
        None
        if os.path.exists(model_id)
        else SingleModelCacheEntry(model_id, task="text-generation", config=config, neuron_config=neuron_config)
    )
    with hub_neuronx_cache(entry=cache_entry):
        traced_model = NxDPreTrainedModel.compile(
            neuron_config=neuron_config,
            graph_builders=graph_builders,
            compiler_args=cls.get_compiler_args(neuron_config),
        )
    model = cls(
        config=config,
        neuron_config=neuron_config,
        traced_model=traced_model,
        graph_builders=graph_builders,
    )
    if load_weights:
        model.load_weights(
            model_id,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
        )
    return model
| 383 | + |
@classmethod
def _from_pretrained(
    cls,
    model_id: "str | Path",
    config: "PretrainedConfig",
    revision: str | None = None,
    token: bool | str | None = None,
    cache_dir: str | None = None,
    force_download: bool | None = False,
    local_files_only: bool | None = False,
    **kwargs,
) -> NeuronPreTrainedModel:
    """Load an already compiled model from a local directory or from the hub.

    The neuron configuration is read with `NxDNeuronConfig.from_pretrained(model_id)`
    and checked against the current instance type and the available cores. The
    compiled artifact (`cls.COMPILED_MODEL_FILE_NAME`) is then loaded with
    `torch.jit.load`, either from `model_id` when that path exists, or from a
    temporary directory filled by `snapshot_download`. Finally the weights are
    loaded through `model.load_weights`.

    Raises:
        ValueError: If the model was compiled for another instance type, or if
            fewer cores than `neuron_config.tp_degree` are available.
    """
    if kwargs:
        logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
    neuron_config = NxDNeuronConfig.from_pretrained(model_id)

    # The compiled artifacts are only usable on the instance type they target.
    if neuron_config.target != current_instance_type():
        raise ValueError(
            f"The model was compiled for {neuron_config.target} but the current instance type is "
            f"{current_instance_type()}. Please use a compatible instance type."
        )
    # There must be at least as many cores as the tensor parallel size.
    if get_available_cores() < neuron_config.tp_degree:
        raise ValueError(
            f"The model requires at least {neuron_config.tp_degree} Neuron cores but only "
            f"{get_available_cores()} are available. Please use a compatible instance type."
        )

    # Options shared by the artifact download and the weights loading.
    fetch_options = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "local_files_only": local_files_only,
        "token": token,
    }
    artifact_name = cls.COMPILED_MODEL_FILE_NAME
    if os.path.exists(model_id):
        # Local export: the compiled artifact sits directly under model_id.
        artifact_path = os.path.join(model_id, artifact_name)
        traced_model = torch.jit.load(artifact_path)
    else:
        # The model_id is a model hub id: download only the compiled artifact
        # and load it while the temporary directory still exists.
        with TemporaryDirectory() as download_dir:
            snapshot_download(
                repo_id=model_id,
                revision=revision,
                local_dir=download_dir,
                allow_patterns=[artifact_name],
                **fetch_options,
            )
            artifact_path = os.path.join(download_dir, artifact_name)
            traced_model = torch.jit.load(artifact_path)

    model = cls(
        config=config,
        neuron_config=neuron_config,
        traced_model=traced_model,
        graph_builders=cls.create_graph_builders(config=config, neuron_config=neuron_config),
    )
    model.load_weights(model_id, **fetch_options)
    return model
| 442 | + |
313 | 443 | def _save_pretrained(self, save_directory: str | Path, **kwargs): |
314 | 444 | model_name_or_path = getattr(self.config, "_name_or_path") |
315 | 445 | # If the model was exported from a local path, we need to save the checkpoint (note that we also shard it) |
|
0 commit comments