Skip to content

Commit 1c1fc1a

Browse files
committed
lint
1 parent b6ebed8 commit 1c1fc1a

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4244,7 +4244,9 @@ def aten_index_put(
42444244
if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1:
42454245
# shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5)
42464246
# This case was only found in ops_data test.
4247-
return _aten_index_put_scatter_nd(self, [op.Reshape(indices[0], [-1])], values, accumulate)
4247+
return _aten_index_put_scatter_nd(
4248+
self, [op.Reshape(indices[0], [-1])], values, accumulate
4249+
)
42484250

42494251
def _make_reshape_list_broadcastable(reshape_list, values_shape):
42504252
# Remove ones until the rank of reshape_list matches values_shape.

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,10 @@ def _im2col_input_wrangler(
313313
def _index_put_input_wrangler(
314314
args: list[Any], kwargs: dict[str, Any]
315315
) -> tuple[list[Any], dict[str, Any]]:
316-
args[1] = [(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem)) for elem in args[1]]
316+
args[1] = [
317+
(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem))
318+
for elem in args[1]
319+
]
317320
return args, kwargs
318321

319322

0 commit comments

Comments
 (0)