Skip to content

Commit cecbb58

Browse files
junjiang-lab authored and copybara-github committed
Add lowering for aten.unfold.
PiperOrigin-RevId: 822295986
1 parent fbf4ff3 commit cecbb58

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

ai_edge_torch/odml_torch/lowerings/_basic.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
325325
return stablehlo.concatenate(non_empty_tensors, dim)
326326

327327

328+
# Schema:
# - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
@lower(torch.ops.aten.unfold.default)
def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
  """Lowers aten.unfold to a single stablehlo.gather.

  Extracts all sliding windows of length `size`, taken every `step`
  elements, along dimension `dim` of `x`. Matching torch semantics, the
  result has the windowed dimension replaced by the number of windows and
  a new trailing dimension of length `size` holding each window's
  elements.

  Args:
    lctx: The lowering context (unused here).
    x: The input tensor value. Assumes a static (ranked, fully known)
      shape, since `x.type.shape` is read to compute window counts —
      TODO confirm dynamic shapes are rejected upstream.
    dim: Dimension to unfold; may be negative (normalized below).
    size: Window length along `dim`.
    step: Stride between consecutive window starts.

  Returns:
    An ir.Value of shape batch_shape + [size], where batch_shape is
    x.shape with x.shape[dim] replaced by the number of windows.
  """
  x_shape = x.type.shape
  rank = len(x_shape)
  # Normalize a negative dim to its positive index.
  if dim < 0:
    dim += rank

  # Number of complete windows that fit; matches torch's
  # (shape[dim] - size) // step + 1.
  num_windows = (x_shape[dim] - size) // step + 1
  batch_shape = list(x_shape[:dim]) + [num_windows] + list(x_shape[dim + 1 :])

  # Create start_indices for gather.
  # The shape of start_indices will be batch_shape + [rank].
  # start_indices[b_0,...,b_{rank-1}] will be [p_0,...,p_{rank-1}] where
  # p_j = b_j for j != dim and p_dim = b_dim * step.
  indices_parts = []
  i64 = ir.IntegerType.get_signless(64)
  for i in range(rank):
    # bshape is all-ones except along axis i, so the iota below can be
    # broadcast to the full batch_shape varying only along axis i.
    bshape = [1] * rank
    bshape[i] = batch_shape[i]
    dim_len = batch_shape[i]

    # iota = [0, 1, ..., dim_len - 1] as an i64 vector.
    iota = stablehlo.IotaOp(
        ir.RankedTensorType.get([dim_len], i64),
        iota_dimension=ir.IntegerAttr.get(i64, 0),
    ).result
    if i == dim:
      # Along the unfolded dimension, window w starts at element w * step.
      iota = stablehlo.multiply(iota, utils.splat(step, i64, [dim_len]))

    iota_reshaped = stablehlo.reshape(
        ir.RankedTensorType.get(bshape, i64), iota
    )
    indices_parts.append(
        stablehlo.broadcast_in_dim(
            ir.RankedTensorType.get(batch_shape, i64),
            iota_reshaped,
            ir.DenseI64ArrayAttr.get(list(range(rank))),
        )
    )

  # For each dimension i, indices_parts[i] contains the i-th coordinate
  # of start_indices. We unsqueeze each part to shape batch_shape + [1]
  # and concatenate along the new dimension to produce start_indices of
  # shape batch_shape + [rank].
  unsqueezed_parts = [
      stablehlo.reshape(ir.RankedTensorType.get(batch_shape + [1], i64), part)
      for part in indices_parts
  ]
  start_indices = stablehlo.concatenate(
      unsqueezed_parts, ir.IntegerAttr.get(i64, rank)
  )

  # Each gathered slice is a single window: length `size` along `dim`,
  # length 1 everywhere else.
  slice_sizes_list = [1] * rank
  slice_sizes_list[dim] = size
  slice_sizes = ir.DenseI64ArrayAttr.get(slice_sizes_list)

  # Collapse every size-1 slice dimension so only the window dimension
  # survives as an offset dimension in the output.
  collapsed_slice_dims_list = [i for i in range(rank) if i != dim]

  dnums = stablehlo.GatherDimensionNumbers.get(
      # The lone surviving slice dim (the window of length `size`) lands
      # at output position `rank`, i.e. as the trailing dimension —
      # matching torch.unfold's layout.
      offset_dims=[rank],
      collapsed_slice_dims=collapsed_slice_dims_list,
      operand_batching_dims=[],
      start_indices_batching_dims=[],
      # Coordinate j of each index vector maps to operand dimension j.
      start_index_map=list(range(rank)),
      index_vector_dim=rank,
  )

  return stablehlo.gather(
      x,
      start_indices,
      dnums,
      slice_sizes,
      indices_are_sorted=ir.BoolAttr.get(False),
  )
328407
# Schema:
329408
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
330409
# start=None, SymInt? end=None, SymInt step=1) -> Tensor

ai_edge_torch/odml_torch/test/test_core_aten_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ def _run_export_and_compare(
393393
("aten_tril_0", torch.ops.aten.tril, (rnd(torch.float32, (10, 10)),), dict()),
394394
("aten_trunc_0", torch.ops.aten.trunc, (rnd(torch.float32, (10, 10)),), dict()),
395395
("aten_unbind_copy_int_0", torch.ops.aten.unbind_copy.int, (rnd(torch.float32, (10, 10)),), dict()),
396+
("aten_unfold_copy_0", torch.ops.aten.unfold, (rnd(torch.float32, (4, 4)), 0, 2, 1), dict()),
397+
("aten_unfold_copy_1", torch.ops.aten.unfold, (rnd(torch.float32, (4, 4)), 1, 2, 1), dict()),
398+
("aten_unfold_copy_2", torch.ops.aten.unfold, (rnd(torch.float32, (5, 5)), 1, 1, 2), dict()),
399+
("aten_unfold_copy_3", torch.ops.aten.unfold, (rnd(torch.float32, (2, 4, 4)), -1, 2, 1), dict()),
396400
("aten_unsqueeze_copy_0", torch.ops.aten.unsqueeze_copy, (rnd(torch.float32, (2, 0, 2)), 1,), dict()),
397401
("aten_upsample_bilinear2d_0", torch.ops.aten.upsample_bilinear2d, (rnd(torch.float32, (1, 3, 2, 10)), [3, 20], False,), dict()),
398402
("aten_upsample_nearest2d_0", torch.ops.aten.upsample_nearest2d, (rnd(torch.float32, (1, 3, 2, 10)), [3, 20],), dict()),

0 commit comments

Comments
 (0)