diff --git a/exir/tests/control_flow_models.py b/exir/tests/control_flow_models.py index 0815f8fbc5b..5aab85cc45a 100644 --- a/exir/tests/control_flow_models.py +++ b/exir/tests/control_flow_models.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +from torch._higher_order_ops.map import map as torch_map from torch.nn import Module # @manual @@ -87,7 +88,7 @@ def forward(self, xs, y): def f(x, y): return x + y - return torch.ops.higher_order.map(f, xs, y) + xs + return torch_map(f, xs, y) + xs def get_random_inputs(self): return torch.rand(2, 4), torch.rand(4) @@ -101,7 +102,7 @@ def forward(self, xs, y): def f(x, y): return x + y - return torch.ops.higher_order.map(f, xs, y) + xs + return torch_map(f, xs, y) + xs def get_upper_bound_inputs(self): return torch.rand(4, 4), torch.rand(4)