diff --git a/graph_net/test/bert_model_test.py b/graph_net/test/bert_model_test.py index 7aab7115e..c85a2d927 100644 --- a/graph_net/test/bert_model_test.py +++ b/graph_net/test/bert_model_test.py @@ -24,7 +24,7 @@ def create_model(): inputs = {k: v.to(device) for k, v in inputs.items()} model = create_model() - model = graph_net.torch.extract(name=get_model_name())(model) + model = graph_net.torch.extract(name=get_model_name(), dynamic=True)(model) print("Running inference...") output = model(**inputs) diff --git a/graph_net/test/vision_model_test.py b/graph_net/test/vision_model_test.py index ea51bc533..31100cb2f 100644 --- a/graph_net/test/vision_model_test.py +++ b/graph_net/test/vision_model_test.py @@ -28,7 +28,7 @@ model.to(device) normalized_input = normalized_input.to(device) - model = graph_net.torch.extract(name="resnet18")(model) + model = graph_net.torch.extract(name="resnet18", dynamic=True)(model) print("Running inference...") print("Input shape:", normalized_input.shape)