Hi there,
I have successfully used nvidia-modelopt v0.29.0 with torch==2.0.0+cu12 to prune YOLOv8 models.
Recently, while trying to prune in one of my new environments with torch==2.9.0+cu13 and nvidia-modelopt v0.37.0, I ran into a bug while parsing a model and a prune checkpoint.
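For context, the failing call is my trainer's pruning hook, which boils down to the standard mtp.prune entry point. Below is a minimal sketch of the shape of that call; the stand-in model, constraint, score function, and checkpoint path are placeholders rather than my actual setup.

import torch
import torch.nn as nn
import modelopt.torch.prune as mtp

# Placeholder network; the real run passes a YOLOv8 model.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 8, 3, padding=1),
)
dummy_input = torch.randn(1, 3, 640, 640)

def score_func(m):
    # Placeholder; the real score function evaluates validation accuracy.
    return 0.0

model, prune_res = mtp.prune(
    model,
    mode="fastnas",                # matches the fastnas path in the traceback
    constraints={"flops": "75%"},  # placeholder constraint
    dummy_input=dummy_input,
    config={"score_func": score_func, "checkpoint": "prune_ckpt.pth"},
)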
The resulting traceback is as follows:
...
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/ultralytics/engine/trainer.py", line 928, in _setup_pruning
[rank0]: self.model, prune_res = mtp.prune(
[rank0]: ^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/prune/pruning.py", line 203, in prune
[rank0]: model = apply_mode(model, mode, registry=PruneModeRegistry)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/opt/conversion.py", line 418, in apply_mode
[rank0]: model, metadata = get_mode(m).convert(model, config, **kwargs) # type: ignore [call-arg]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/prune/fastnas.py", line 70, in convert_fastnas_searchspace
[rank0]: return convert_searchspace(model, config, FastNASPatchManager)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/nas/autonas.py", line 596, in convert_searchspace
[rank0]: search_space = generate_search_space(model, rules=config.model_dump())
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/nas/search_space.py", line 257, in generate_search_space
[rank0]: sym_map = analyze_symbols(model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 1356, in analyze_symbols
[rank0]: GraphDependencyProcessor(mod, graph_collection, sym_map).process()
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 482, in process
[rank0]: process_constraint(node, node_id, input_nodes)
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/modules/concat.py", line 294, in process [rank0]: tensors = self._get_root_nodes(tensors)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 319, in _get_root_nodes
[rank0]: return [self._dependency_map.root(n) for n in nodes]
[rank0]: ^^^^^
[rank0]: TypeError: 'Node' object is not iterable
It looks like the control flow breaks in NodeProcessor -> _synchronize_nodes -> _get_root_nodes when _get_root_nodes receives a single Node instead of the expected list[Node]. A simple fix seems to be to add a type check, although this part of the code does not look like the actual origin of the problem:
def _get_root_nodes(self, nodes: list[Node]) -> list[Node]:
    """Return root nodes of nodes according to dependency map."""
    # Added this line as a fix: wrap a bare Node into a one-element list.
    nodes = [nodes] if isinstance(nodes, Node) else nodes
    return [self._dependency_map.root(n) for n in nodes]
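If anyone else hits this before a proper fix lands, the same normalization can be applied at runtime without editing site-packages. This is only a sketch: it assumes _get_root_nodes is defined on NodeProcessor in modelopt/torch/trace/analyzer.py, which the call chain above suggests but which I have not verified.

import modelopt.torch.trace.analyzer as analyzer

# Keep a reference to the original method so the patch only adds normalization.
_orig_get_root_nodes = analyzer.NodeProcessor._get_root_nodes

def _patched_get_root_nodes(self, nodes):
    # Wrap a bare Node into a one-element list before delegating.
    if not isinstance(nodes, (list, tuple)):
        nodes = [nodes]
    return _orig_get_root_nodes(self, list(nodes))

analyzer.NodeProcessor._get_root_nodes = _patched_get_root_nodes

Applying this before calling mtp.prune should behave the same as the in-place edit above, though it equally papers over whatever upstream change started passing a single Node here.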
To reiterate, the same model, checkpoint, and parsing worked fine with v0.29.0.
Any help or guidance in resolving this, or an explanation of the expected behavior, would be appreciated.
System information
- Container used (if applicable): ?
- OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 24.04
- CPU architecture (x86_64, aarch64): x86_64
- GPU name (e.g. H100, A100, L40S): 5070 Ti Super
- GPU memory size: 16 GB
- Number of GPUs: 2
- Library versions (if applicable):
- Python: ?
- ModelOpt version or commit hash: 0.37.0
- CUDA: 13.0
- PyTorch: 2.9.0+cu13
- Transformers: ?
- TensorRT-LLM: ?
- ONNXRuntime: ?
- TensorRT: ?
- Any other details that may help: ?