@@ -325,6 +325,85 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
325325 return stablehlo .concatenate (non_empty_tensors , dim )
326326
327327
# Schema:
# - aten::unfold(Tensor self, int dim, int size, int step) -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
@lower(torch.ops.aten.unfold.default)
def _aten_unfold(lctx, x: ir.Value, dim: int, size: int, step: int):
  """Lowers aten.unfold to a single stablehlo.gather.

  Every window of length `size`, taken every `step` elements along `dim`,
  is gathered as one slice of `x`; the window contents show up as a new
  trailing dimension of the result.
  """
  in_shape = x.type.shape
  rank = len(in_shape)
  dim = dim + rank if dim < 0 else dim

  # Number of windows that fit along `dim`; the result's batch dimensions
  # are the input shape with that axis replaced by the window count.
  windows = (in_shape[dim] - size) // step + 1
  batch_shape = list(in_shape)
  batch_shape[dim] = windows

  idx_ty = ir.IntegerType.get_signless(64)

  def _coord(axis: int) -> ir.Value:
    # The `axis`-th coordinate of every window start, broadcast to
    # batch_shape. Along `dim` the start position advances by `step`;
    # along every other axis it is the identity.
    extent = batch_shape[axis]
    coords = stablehlo.IotaOp(
        ir.RankedTensorType.get([extent], idx_ty),
        iota_dimension=ir.IntegerAttr.get(idx_ty, 0),
    ).result
    if axis == dim:
      coords = stablehlo.multiply(coords, utils.splat(step, idx_ty, [extent]))
    shape_1d = [extent if a == axis else 1 for a in range(rank)]
    coords = stablehlo.reshape(
        ir.RankedTensorType.get(shape_1d, idx_ty), coords
    )
    return stablehlo.broadcast_in_dim(
        ir.RankedTensorType.get(batch_shape, idx_ty),
        coords,
        ir.DenseI64ArrayAttr.get(list(range(rank))),
    )

  # Stack the per-axis coordinates into start_indices of shape
  # batch_shape + [rank]: start_indices[b_0, ..., b_{rank-1}] is the
  # window start [p_0, ..., p_{rank-1}] where p_dim = b_dim * step and
  # p_j = b_j for every other axis j.
  stacked = [
      stablehlo.reshape(
          ir.RankedTensorType.get(batch_shape + [1], idx_ty), _coord(a)
      )
      for a in range(rank)
  ]
  start_indices = stablehlo.concatenate(
      stacked, ir.IntegerAttr.get(idx_ty, rank)
  )

  # Each gathered slice covers `size` elements along `dim` and a single
  # element everywhere else; the non-`dim` slice dims collapse away so the
  # window lands in one trailing offset dimension.
  slice_sizes = [1] * rank
  slice_sizes[dim] = size
  gather_dnums = stablehlo.GatherDimensionNumbers.get(
      offset_dims=[rank],
      collapsed_slice_dims=[a for a in range(rank) if a != dim],
      operand_batching_dims=[],
      start_indices_batching_dims=[],
      start_index_map=list(range(rank)),
      index_vector_dim=rank,
  )

  return stablehlo.gather(
      x,
      start_indices,
      gather_dnums,
      ir.DenseI64ArrayAttr.get(slice_sizes),
      indices_are_sorted=ir.BoolAttr.get(False),
  )
405+
406+
328407# Schema:
329408# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
330409# start=None, SymInt? end=None, SymInt step=1) -> Tensor
0 commit comments