@@ -254,3 +254,90 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
   %output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
   return %output : !torch.vtensor<[1,3,32,32],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.topk(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.topk") %[[T0]] {dim = -1 : i64, k = 80 : i64, largest = true, sorted = true, torch_operand_names = ["self"]} :
+// CHECK-SAME: tensor<?x2304xf32> -> tensor<?x80xf32>, tensor<?x80xi64>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x80xf32> -> !torch.vtensor<[?,80],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[?,80],f32>
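+// Note: the single tensor operand is recorded in torch_operand_names, while the scalar
+// arguments (k, dim, largest, sorted) appear to be folded into attributes on the custom op;
+// only the values result of topk is used by the function below.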
+func.func @torch.aten.topk(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
+  %int-1 = torch.constant.int -1
+  %int80 = torch.constant.int 80
+  %true = torch.constant.bool true
+  %output0, %output1 = torch.aten.topk %input, %int80, %int-1, %true, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,80],f32>, !torch.vtensor<[?,80],si64>
+  return %output0 : !torch.vtensor<[?,80],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sort(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.sort") %[[T0]] {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} :
+// CHECK-SAME: tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x2304xf32> -> !torch.vtensor<[?,2304],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[?,2304],f32>
+func.func @torch.aten.sort(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+  %int-1 = torch.constant.int -1
+  %true = torch.constant.bool true
+  %output0, %output1 = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>
+  return %output0 : !torch.vtensor<[?,2304],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cumsum(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?],si32> -> tensor<?xi32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.cumsum") %[[T0]] {dim = 0 : i64, torch_operand_names = ["self"]} : tensor<?xi32> -> tensor<?xi64>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?xi64> -> !torch.vtensor<[?],si64>
+// CHECK: return %[[RES]] : !torch.vtensor<[?],si64>
+func.func @torch.aten.cumsum(%input: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %1 = torch.aten.cumsum %input, %int0, %none : !torch.vtensor<[?],si32>, !torch.int, !torch.none -> !torch.vtensor<[?],si64>
+  return %1 : !torch.vtensor<[?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.min.dim(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,80],f32> -> tensor<?x80xf32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.min.dim") %[[T0]] {dim = 1 : i64, keepdim = false, torch_operand_names = ["self"]} :
+// CHECK-SAME: tensor<?x80xf32> -> tensor<?xf32>, tensor<?xi64>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
+func.func @torch.aten.min.dim(%input: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %output0, %output1 = torch.aten.min.dim %input, %int1, %false : !torch.vtensor<[?,80],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
+  return %output0 : !torch.vtensor<[?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view_dynamic_shape(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,384,16],f32>, %[[ARG1:.*]]: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,384,16],f32> -> tensor<?x384x16xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x2736x16xf32>
+// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.view") %[[T0]], %[[DIM]] {size = array<i64: -9223372036854775808, 24, 16, 16>, torch_operand_names = ["self", "idx_0"]} :
+// CHECK-SAME: tensor<?x384x16xf32>, index -> tensor<?x24x16x16xf32>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x24x16x16xf32> -> !torch.vtensor<[?,24,16,16],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[?,24,16,16],f32>
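+// Note: -9223372036854775808 is INT64_MIN (ShapedType::kDynamic in recent MLIR), likely used
+// here as a placeholder in the static `size` array for the dynamic leading dim, whose runtime
+// value is passed separately as the `idx_0` index operand.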
+func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %arg1: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+  %c0 = arith.constant 0 : index
+  %int24 = torch.constant.int 24
+  %int16 = torch.constant.int 16
+  %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
+  %1 = arith.index_cast %dim_32 : index to i64
+  %2 = torch_c.from_i64 %1
+  %3 = torch.prim.ListConstruct %2, %int24, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
+  return %4 : !torch.vtensor<[?,24,16,16],f32>
+}