We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d1086d6 commit e26d0e1Copy full SHA for e26d0e1
graph_net/test/bert_model_test.py
@@ -24,7 +24,7 @@ def create_model():
24
inputs = {k: v.to(device) for k, v in inputs.items()}
25
26
model = create_model()
27
- model = graph_net.torch.extract(name=get_model_name())(model)
+ model = graph_net.torch.extract(name=get_model_name(), dynamic=True)(model)
28
29
print("Running inference...")
30
output = model(**inputs)
graph_net/test/vision_model_test.py
@@ -28,7 +28,7 @@
model.to(device)
normalized_input = normalized_input.to(device)
31
- model = graph_net.torch.extract(name="resnet18")(model)
+ model = graph_net.torch.extract(name="resnet18", dynamic=True)(model)
32
33
34
print("Input shape:", normalized_input.shape)
0 commit comments