discover_models.py tests which torchvision models are FX-traceable and can be used with the graph characterization tools. It scans the torchvision model zoo and generates a registry of compatible models.
Key Capabilities:
- Test FX-traceability of all torchvision models
- Filter by model family (ResNet, MobileNet, ViT, etc.)
- Generate MODEL_REGISTRY code for other tools
- Test individual models
- Custom skip patterns
Target Users:
- New users discovering available models
- Tool developers maintaining model registries
- Engineers debugging tracing issues
Requirements:
```bash
pip install torch torchvision
```
Verify Installation:
```bash
python3 cli/discover_models.py --help
```
Run discovery with the default settings:
```bash
python3 cli/discover_models.py
```
Output:
- Summary statistics
- FX-traceable models grouped by family
- Skipped models (detection, segmentation, video)
See the test result for each model:
```bash
python3 cli/discover_models.py --verbose
```
Output:
```
Testing models:
✓ resnet18 - FX-traceable
✓ resnet34 - FX-traceable
✓ resnet50 - FX-traceable
✗ fasterrcnn_resnet50_fpn - TypeError: forward() takes 2 positional...
✓ mobilenet_v2 - FX-traceable
✓ vit_b_16 - FX-traceable
...
```
Export Python code for MODEL_REGISTRY:
```bash
python3 cli/discover_models.py --generate-code
```
Output:
```python
MODEL_REGISTRY = {
    # Resnet family
    'resnet18': models.resnet18,
    'resnet34': models.resnet34,
    'resnet50': models.resnet50,
    # Mobilenet family
    'mobilenet_v2': models.mobilenet_v2,
    'mobilenet_v3_small': models.mobilenet_v3_small,
    # Vit family
    'vit_b_16': models.vit_b_16,
    'vit_l_16': models.vit_l_16,
    ...
}
```
Use Case: Copy-paste into `profile_graph.py` or custom analysis scripts.
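The script's own generator is not reproduced in this guide. As a rough sketch of how registry text like the output above could be produced, the following groups names by a hypothetical family rule (the leading run of letters in the name, via the made-up helpers `family_of` and `generate_registry_code`) and emits one comment line per family. The real tool evidently groups differently, since its report lists `resnext50_32x4d` and `wide_resnet50_2` under RESNET.

```python
import re
from collections import defaultdict


def family_of(name: str) -> str:
    """Hypothetical grouping rule: the leading run of lowercase letters in the name."""
    match = re.match(r"[a-z]+", name)
    return match.group(0) if match else name


def generate_registry_code(model_names: list[str]) -> str:
    """Return MODEL_REGISTRY source text with one comment line per family."""
    groups: dict[str, list[str]] = defaultdict(list)
    for name in model_names:
        groups[family_of(name)].append(name)

    lines = ["MODEL_REGISTRY = {"]
    for family, names in groups.items():  # dicts keep first-seen family order
        lines.append(f"    # {family.capitalize()} family")
        for name in names:
            lines.append(f"    '{name}': models.{name},")
    lines.append("}")
    return "\n".join(lines)


print(generate_registry_code(["resnet18", "resnet50", "mobilenet_v2", "vit_b_16"]))
```

The emitted entries reference `models.<name>`, so wherever the text is pasted needs `from torchvision import models` in scope.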
Test a specific model:
```bash
python3 cli/discover_models.py --test-model resnet18
```
Output:
```
Testing resnet18...
✓ resnet18 - FX-traceable
✓ resnet18 is FX-traceable!
```
| Argument | Type | Description |
|---|---|---|
| `--verbose`, `-v` | flag | Show detailed test results for each model |
| `--generate-code`, `-g` | flag | Generate MODEL_REGISTRY Python code |
| `--test-model` | str | Test a single specific model |
| `--skip-patterns` | str[] | Space-separated patterns to skip |
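For scripts that need to accept the same options, the interface in the table maps naturally onto `argparse`. This is a hypothetical reconstruction of the documented flags, not the script's source; in particular, the default pattern list is transcribed from the categories below and may not match the script exactly.

```python
import argparse

# Transcribed from the default skip categories in this guide (assumption);
# the guide writes the quantized entry as quantized_*.
DEFAULT_SKIP_PATTERNS = [
    "rcnn", "retinanet", "fcos", "ssd",             # detection
    "deeplabv3", "fcn", "lraspp",                   # segmentation
    "raft", "r3d", "r2plus1d", "mc3", "s3d",        # video / optical flow
    "mvit", "swin3d",
    "quantized",                                    # quantized models
]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Find torchvision models that torch.fx can trace")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Show detailed test results for each model")
    parser.add_argument("--generate-code", "-g", action="store_true",
                        help="Generate MODEL_REGISTRY Python code")
    parser.add_argument("--test-model", type=str,
                        help="Test a single specific model")
    # nargs="*": values given on the command line replace the default list
    parser.add_argument("--skip-patterns", nargs="*", default=DEFAULT_SKIP_PATTERNS,
                        help="Space-separated patterns to skip")
    return parser


args = build_parser().parse_args(["-g", "--skip-patterns", "fcos", "vit"])
print(args.generate_code, args.skip_patterns)
```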
By default, the following model categories are skipped:
Detection Models:
- `rcnn` (Faster R-CNN, Mask R-CNN)
- `retinanet`
- `fcos`
- `ssd`

Segmentation Models:
- `deeplabv3`
- `fcn`
- `lraspp`

Video Models (5D input):
- `raft` (optical flow)
- `r3d`, `r2plus1d`, `mc3`, `s3d`
- `mvit`, `swin3d`

Quantized Models:
- `quantized_*`
Reason: These models require special inputs (multiple tensors, 5D inputs, etc.) that don't work with standard FX symbolic tracing.
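How the patterns are matched is not spelled out here. A minimal sketch, assuming a case-insensitive substring test that ignores empty patterns (so that `--skip-patterns ""` behaves as "skip nothing"); `should_skip` is a hypothetical helper:

```python
def should_skip(model_name: str, patterns: list[str]) -> bool:
    """Assumed rule: skip a model if any non-empty pattern is a substring of its name."""
    name = model_name.lower()
    return any(p and p.lower() in name for p in patterns)


# The default categories listed above, written as plain substrings (assumption).
DEFAULT_PATTERNS = ["rcnn", "retinanet", "fcos", "ssd",
                    "deeplabv3", "fcn", "lraspp",
                    "raft", "r3d", "r2plus1d", "mc3", "s3d", "mvit", "swin3d",
                    "quantized"]

names = ["resnet18", "fasterrcnn_resnet50_fpn", "ssdlite320_mobilenet_v3_large",
         "deeplabv3_resnet50", "r3d_18", "vit_b_16", "quantized_resnet18"]
kept = [n for n in names if not should_skip(n, DEFAULT_PATTERNS)]
print(kept)  # ['resnet18', 'vit_b_16']
```

Substring matching keeps the pattern list short but can over-match: under this rule, passing `vit` would also drop `maxvit_t`.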
Override default skip patterns:
```bash
python3 cli/discover_models.py --skip-patterns fcos vit ssd
```
Use Case: Test specific model families that were previously skipped
Example output:
```
================================================================================
SUMMARY
================================================================================
FX-traceable: 147 models
Failed: 23 models
Skipped: 85 models (detection/segmentation/video/quantized)
================================================================================
FX-TRACEABLE MODELS BY FAMILY
================================================================================
RESNET (8):
  resnet18
  resnet34
  resnet50
  resnet101
  resnet152
  resnext50_32x4d
  resnext101_32x8d
  wide_resnet50_2
MOBILENET (5):
  mobilenet_v2
  mobilenet_v3_large
  mobilenet_v3_small
EFFICIENTNET (8):
  efficientnet_b0
  efficientnet_b1
  efficientnet_b2
  efficientnet_b3
  efficientnet_b4
  efficientnet_b5
  efficientnet_b6
  efficientnet_b7
  efficientnet_v2_s
  efficientnet_v2_m
  efficientnet_v2_l
VIT (10):
  vit_b_16
  vit_b_32
  vit_l_16
  vit_l_32
  vit_h_14
CONVNEXT (6):
  convnext_tiny
  convnext_small
  convnext_base
  convnext_large
VGG (4):
  vgg11
  vgg13
  vgg16
  vgg19
  vgg11_bn
  vgg13_bn
  vgg16_bn
  vgg19_bn
DENSENET (4):
  densenet121
  densenet161
  densenet169
  densenet201
And more...
```
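To consume this report from another script, a small parser can pull out the summary counts and the per-family model lists. The sketch below assumes the layout shown above (a `NAME (N):` header followed by one model name per line) and records the names actually listed rather than trusting the count in each header, since the two need not agree in abridged output. `parse_report` is a hypothetical helper.

```python
import re

SUMMARY_LINE = re.compile(r"^(FX-traceable|Failed|Skipped): (\d+) models")
FAMILY_HEADER = re.compile(r"^([A-Z0-9_]+) \((\d+)\):$")
MODEL_NAME = re.compile(r"^[a-z][a-z0-9_]*$")


def parse_report(text: str) -> tuple[dict[str, int], dict[str, list[str]]]:
    """Return (summary counts, {FAMILY: [model names]}) from the text report."""
    counts: dict[str, int] = {}
    families: dict[str, list[str]] = {}
    current = None  # family header seen most recently
    for line in (raw.strip() for raw in text.splitlines()):
        if (m := SUMMARY_LINE.match(line)):
            counts[m.group(1)] = int(m.group(2))
        elif (m := FAMILY_HEADER.match(line)):
            current = m.group(1)
            families[current] = []
        elif current is not None and MODEL_NAME.match(line):
            families[current].append(line)
    return counts, families


sample = """FX-traceable: 147 models
RESNET (8):
  resnet18
  resnet34
And more..."""
print(parse_report(sample))  # ({'FX-traceable': 147}, {'RESNET': ['resnet18', 'resnet34']})
```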
Before using a model in analysis, verify it's FX-traceable:
```bash
python3 cli/discover_models.py --test-model efficientnet_b0
python3 cli/discover_models.py --verbose | grep "vit_"
```
After a torchvision update, regenerate the registry:
```bash
python3 cli/discover_models.py --generate-code > new_registry.txt
# Review and copy to profile_graph.py
```
See why a model fails:
```bash
python3 cli/discover_models.py --test-model fasterrcnn_resnet50_fpn --verbose
```
The output shows the error message:
```
✗ fasterrcnn_resnet50_fpn - TypeError: forward() takes 2 positional...
```
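```bash
# Skip the other major families and keep only ConvNeXt lines in the output
python3 cli/discover_models.py \
  --skip-patterns resnet mobile efficient vit vgg dense shuffle squeeze \
  --verbose | grep convnext
```
PyTorch FX (`torch.fx`) performs symbolic tracing of a model:
- Records the operations in `forward()` as a graph
- Lets tensor shapes be propagated through that graph afterwards (for example with `torch.fx.passes.shape_prop.ShapeProp`)
- Enables graph transformations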
Requirement for our tools: Models must be FX-traceable to use partitioning and hardware mapping.
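A minimal illustration of the first two points, using a small made-up network (`TinyNet` is illustrative only): `symbolic_trace` records `forward()` as a graph without running it on data, and `ShapeProp` then runs the traced module once on an example input and stores each node's output shape in `node.meta["tensor_meta"]`.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = torch.flatten(self.pool(x), 1)
        return self.fc(x)


traced = symbolic_trace(TinyNet())                      # record forward() as a graph
ShapeProp(traced).propagate(torch.randn(1, 3, 32, 32))  # run once to attach shapes

for node in traced.graph.nodes:  # op kind, target, and recorded output shape per node
    meta = node.meta.get("tensor_meta")
    shape = tuple(meta.shape) if meta is not None else None
    print(f"{node.op:<14} {str(node.target):<20} {shape}")
```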
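Common Failure Reasons:
- Dynamic Control Flow
```python
# NOT FX-traceable: the branch depends on a value of the traced input
if x.shape[0] > 10:
    return self.path_a(x)
else:
    return self.path_b(x)
```
- Multiple Inputs/Outputs
```python
# Detection models need (images, targets)
def forward(self, images, targets=None):
    ...
```
- Python Built-in Types
```python
# Treating a traced tensor as a Python sequence fails
def forward(self, x):
    results = []
    for row in x:  # iterating over a traced tensor raises a TraceError
        results.append(self.layer(row))
    return torch.stack(results)
```
Building a list in a loop over `self.layers` (an `nn.ModuleList`) is usually fine, because that loop is unrolled at trace time; the problem is Python built-ins such as iteration or `len()` applied to traced tensors.
- Non-Tensor Operations
```python
# Assertions and other Python-level checks on traced values
assert x.shape[0] == 1
```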
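The failure modes above can be reproduced with a few toy modules (all names here are illustrative). A branch on a traced value and `len()` on a traced tensor both fail under `symbolic_trace`, while a branch on a plain Python attribute traces fine because it is resolved at trace time.

```python
import torch
from torch import nn
from torch.fx import symbolic_trace


class DataDependentBranch(nn.Module):
    def forward(self, x):
        if x.shape[0] > 10:       # condition depends on the traced input
            return x * 2
        return x


class UsesLen(nn.Module):
    def forward(self, x):
        return x[: len(x) // 2]   # len() of a traced tensor is not supported by default


class FlagBranch(nn.Module):
    def __init__(self, use_relu: bool = True):
        super().__init__()
        self.use_relu = use_relu  # plain Python attribute, fixed at trace time

    def forward(self, x):
        return torch.relu(x) if self.use_relu else x


def try_trace(model: nn.Module) -> bool:
    """Return True if torch.fx can symbolically trace the model."""
    try:
        symbolic_trace(model)
        return True
    except Exception as exc:
        print(f"✗ {type(model).__name__}: {type(exc).__name__}: {exc}")
        return False


for m in (DataDependentBranch(), UsesLen(), FlagBranch()):
    print(type(m).__name__, try_trace(m))
```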
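Patterns that trace cleanly:

✓ Single forward() signature
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.net(x)
```
✓ Static control flow
```python
# The condition is a Python attribute, not a tensor value: it is evaluated once at
# trace time, so only the branch taken then is recorded (re-trace if the flag changes)
x = self.conv1(x) if self.use_conv1 else x
```
✓ Standard PyTorch operations
```python
x = torch.relu(x)
x = self.conv(x)
```
FX-traceable model families:
- ResNet family (resnet18, resnet50, etc.)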
- MobileNet family (mobilenet_v2, mobilenet_v3_*)
- EfficientNet family (efficientnet_b0-b7, efficientnet_v2_*)
- VGG family (vgg16, vgg19)
- DenseNet family (densenet121, densenet161, etc.)
- Vision Transformers (vit_b_16, vit_l_16, etc.)
- ConvNeXt (convnext_tiny, convnext_small, etc.)
- SqueezeNet, ShuffleNet, etc.
Total: 140+ models
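Detection models (skipped by default):
- Faster R-CNN, Mask R-CNN
- RetinaNet
- FCOS
- SSD

Reason: `forward()` takes a list of images plus optional `targets` (required in training mode) instead of a single batched tensor

Segmentation models (skipped by default):
- DeepLabV3
- FCN
- LRASPP

Reason: Complex multi-scale architectures with dynamic shapes

Video models (skipped by default):
- R3D, R(2+1)D
- MViT, Swin3D

Reason: Require 5D input (batch, channels, time, height, width)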
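Tracing itself never sees an input tensor; the 5D requirement matters once an example tensor is pushed through the graph, for instance during shape propagation. A minimal sketch with a toy `Conv3d` stand-in (not an actual torchvision video model), assuming the tool's check feeds a 4D image-style tensor:

```python
import torch
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

# Toy stand-in for a video backbone: Conv3d expects (batch, channels, time, height, width)
video_net = symbolic_trace(nn.Sequential(nn.Conv3d(3, 8, kernel_size=3, padding=1), nn.ReLU()))


def propagates(gm: torch.fx.GraphModule, example: torch.Tensor) -> bool:
    """Return True if shape propagation succeeds for this example input."""
    try:
        ShapeProp(gm).propagate(example)
        return True
    except Exception as exc:  # e.g. a channel-count mismatch inside conv3d
        print(f"✗ input {tuple(example.shape)} rejected: {exc}")
        return False


print(propagates(video_net, torch.randn(1, 3, 8, 112, 112)))  # 5D video clip
print(propagates(video_net, torch.randn(1, 3, 224, 224)))     # 4D image-style input
```

Here the 4D tensor is read by `Conv3d` as an unbatched `(channels, time, height, width)` input with a single channel, which does not match the layer's 3 input channels.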
If `import torchvision` fails:

Solution:
```bash
pip install torchvision
```
Check torchvision version:
```bash
python3 -c "import torchvision; print(torchvision.__version__)"
```
Recommended: torchvision >= 0.13.0

Update:
```bash
pip install --upgrade torchvision
```
If an expected model is missing from the results, it might be:
- In a skipped category (detection, segmentation, video)
- Not yet in torchvision (check version)
- Actually failing FX tracing

Debug:
```bash
python3 cli/discover_models.py --test-model <model_name> --verbose
```
Test all models, including the categories skipped by default:
```bash
python3 cli/discover_models.py --skip-patterns "" --verbose
```
Warning: This will attempt to trace detection/segmentation models (they will fail).
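Use the discovery results from Python:
```python
import subprocess

# Run discovery with default settings and capture its text report
result = subprocess.run(
    ['python3', 'cli/discover_models.py'],
    capture_output=True,
    text=True,
)

# Parse output for traceable models: assumes model names are printed
# with a two-space indent under each family header
lines = result.stdout.split('\n')
traceable = [
    line.strip()
    for line in lines
    if line.startswith('  ') and not line.startswith(' ' * 3)
]
```
Test your own models: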
```python
import torch
from torch.fx import symbolic_trace

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = MyModel()
try:
    traced = symbolic_trace(model)
    print("✓ Model is FX-traceable!")
except Exception as e:
    print(f"✗ Model NOT traceable: {e}")
```
After generating code:
```bash
python3 cli/discover_models.py --generate-code > registry_code.txt
```
Integrate into tools:
- Copy the `MODEL_REGISTRY = {...}` block
- Paste into `profile_graph.py` or your custom script
- Use models by name:
```python
from torchvision import models

MODEL_REGISTRY = {
    # ... generated code, for example:
    'resnet50': models.resnet50,
}

# Load model by name (weights=None builds a randomly initialized model, no download)
model_fn = MODEL_REGISTRY['resnet50']
model = model_fn(weights=None)
```
Related tools:

| Tool | Purpose |
|---|---|
| `profile_graph.py` | Profile discovered models |
| `analyze_graph_mapping.py` | Map models to hardware |
| `compare_models.py` | Compare models across hardware |
- PyTorch FX Documentation: https://pytorch.org/docs/stable/fx.html
- Symbolic Tracing: `experiments/fx/tutorial/`
- Architecture Guide: `CLAUDE.md`

Report issues or request features at the project repository.