torch.trace's NodeProcessor fails to process single nodes when syncing #506

@aravi-miovision

Description

Hi there,

I have successfully used nvidia-modelopt v0.29.0 with torch==2.0.0+cu12 to prune YOLOv8 models.
Recently, while pruning in a new environment with torch==2.9.0+cu13 and nvidia-modelopt v0.37.0, I ran into a bug while parsing a model and a prune checkpoint.

The traceback is as follows:

...
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/ultralytics/engine/trainer.py", line 928, in _setup_pruning
[rank0]:     self.model, prune_res = mtp.prune(
[rank0]:                             ^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/prune/pruning.py", line 203, in prune
[rank0]:     model = apply_mode(model, mode, registry=PruneModeRegistry)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/opt/conversion.py", line 418, in apply_mode
[rank0]:     model, metadata = get_mode(m).convert(model, config, **kwargs)  # type: ignore  [call-arg]
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/prune/fastnas.py", line 70, in convert_fastnas_searchspace
[rank0]:     return convert_searchspace(model, config, FastNASPatchManager)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/nas/autonas.py", line 596, in convert_searchspace
[rank0]:     search_space = generate_search_space(model, rules=config.model_dump())
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/nas/search_space.py", line 257, in generate_search_space
[rank0]:     sym_map = analyze_symbols(model)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 1356, in analyze_symbols
[rank0]:     GraphDependencyProcessor(mod, graph_collection, sym_map).process()
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 482, in process
[rank0]:     process_constraint(node, node_id, input_nodes)
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/modules/concat.py", line 294, in process
[rank0]:     tensors = self._get_root_nodes(tensors)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/miniconda/envs/ultralytics/lib/python3.12/site-packages/modelopt/torch/trace/analyzer.py", line 319, in _get_root_nodes
[rank0]:     return [self._dependency_map.root(n) for n in nodes]
[rank0]:                                                   ^^^^^
[rank0]: TypeError: 'Node' object is not iterable

It seems the control flow breaks along NodeProcessor -> _synchronize_nodes -> _get_root_nodes: _get_root_nodes receives a single Node instead of the expected list[Node]. A simple fix is to add a type check there, although this part of the code does not seem to be where the problem originates.
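The failure mode itself can be reproduced outside modelopt: iterating over an object that implements neither __iter__ nor __getitem__ raises exactly this TypeError. A minimal sketch, assuming a hypothetical Node stand-in for the graph node class used by modelopt's tracer and a simplified get_root_nodes that mimics only the list comprehension (no dependency map):

```python
# Hypothetical stand-in for the tracer's graph node class: a plain object
# with no __iter__/__getitem__, so it is not iterable.
class Node:
    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"Node({self.name!r})"


def get_root_nodes(nodes):
    """Hypothetical simplification of _get_root_nodes: identity 'root' lookup."""
    return [n for n in nodes]


# A list of nodes is fine.
print(get_root_nodes([Node("a"), Node("b")]))  # [Node('a'), Node('b')]

# A single node triggers the same error as in the traceback.
try:
    get_root_nodes(Node("cat_input"))
except TypeError as exc:
    print(exc)  # 'Node' object is not iterable
```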

def _get_root_nodes(self, nodes: list[Node]) -> list[Node]:
    """Return root nodes of nodes according to dependency map."""
    # added this line as a fix: wrap a single Node in a list
    nodes = [nodes] if isinstance(nodes, Node) else nodes
    return [self._dependency_map.root(n) for n in nodes]
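The workaround can also be exercised in isolation. This is a sketch under stated assumptions: modelopt's real dependency map class is not shown in this issue, so DependencyMap (modeled here as a union-find structure whose root() follows parent links to a representative node), Node, and get_root_nodes are all hypothetical stand-ins.

```python
from typing import Iterable, Union


class Node:
    """Hypothetical stand-in for a graph node (not iterable)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"Node({self.name!r})"


class DependencyMap:
    """Hypothetical union-find map; modelopt's actual class may differ."""

    def __init__(self) -> None:
        self._parent: dict[Node, Node] = {}

    def root(self, node: Node) -> Node:
        # Unseen nodes are their own root; otherwise follow parent links.
        parent = self._parent.setdefault(node, node)
        if parent is not node:
            parent = self.root(parent)
            self._parent[node] = parent  # path compression
        return parent

    def link(self, child: Node, parent: Node) -> None:
        # Merge the child's group into the parent's group.
        self._parent[self.root(child)] = self.root(parent)


def get_root_nodes(dep_map: DependencyMap, nodes: Union[Node, Iterable[Node]]) -> list[Node]:
    """Same idea as the workaround: wrap a lone Node before iterating."""
    if isinstance(nodes, Node):
        nodes = [nodes]
    return [dep_map.root(n) for n in nodes]


if __name__ == "__main__":
    conv, bn, cat_in = Node("conv"), Node("bn"), Node("cat_in")
    dep = DependencyMap()
    dep.link(bn, conv)
    print(get_root_nodes(dep, [bn, cat_in]))  # [Node('conv'), Node('cat_in')]
    print(get_root_nodes(dep, bn))            # [Node('conv')]
```

Note that normalizing inside the callee only masks the symptom; it does not explain why the concat handler passes a single Node in the new environment.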

To reiterate, the same model, checkpoint, and parsing worked fine with v0.29.0.

Any help or guidance on resolving this, or on what the expected behavior is, would be appreciated.

System information

  • Container used (if applicable): ?
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 24.04
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): 5070Ti super
  • GPU memory size: 16 GB
  • Number of GPUs: 2
  • Library versions (if applicable):
    • Python: ?
    • ModelOpt version or commit hash: 0.37.0
    • CUDA: 13.0
    • PyTorch: 2.9.0+cu13
    • Transformers: ?
    • TensorRT-LLM: ?
    • ONNXRuntime: ?
    • TensorRT: ?
  • Any other details that may help: ?
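Several version fields above are still "?". One way to fill them is a small script using only the standard library; this is a minimal sketch, and the package list is an assumption (PyPI distribution names; variants such as onnxruntime-gpu are listed separately because they are published under a different name).

```python
import platform
from importlib import metadata

# Assumed PyPI distribution names; adjust to what is actually installed.
PACKAGES = ["nvidia-modelopt", "torch", "transformers", "onnxruntime", "onnxruntime-gpu", "tensorrt"]


def collect_versions(packages=PACKAGES) -> dict[str, str]:
    """Map each package to its installed version, or 'not installed'."""
    versions = {"Python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```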
