Skip to content

Commit bdbd0d1

Browse files
committed
fix(cli): remove import to avoid importing neuronx_distributed
1 parent 35da928 commit bdbd0d1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

optimum/commands/neuron/subcommands.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from typing import TYPE_CHECKING
1818

19-
from ...neuron.models.training import consolidate_model_parallel_checkpoints_to_unified_checkpoint
2019
from ...utils import logging
2120
from ..base import BaseOptimumCLICommand
2221

@@ -53,6 +52,9 @@ def parse_args(parser: "ArgumentParser"):
5352
)
5453

5554
def run(self):
55+
# This is not on top otherwise it will make the CLI require neuronx_distributed
56+
from ...neuron.models.training import consolidate_model_parallel_checkpoints_to_unified_checkpoint
57+
5658
checkpoint_format = "safetensors" if self.args.format == "safetensors" else "pytorch"
5759
logger.info(f"Consolidating checkpoints from {self.args.checkpoint_dir} to the {checkpoint_format} format...")
5860
output_dir = self.args.output_dir

0 commit comments

Comments
 (0)