@@ -325,6 +325,53 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !
325325
326326// -----
327327
// CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic(
// CHECK-SAME: %[[INPUT_F32:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[INPUT_F32_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_F32]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[INPUT_F32_TENSOR]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
// CHECK: %[[RESULT_F32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT_F32]] : !torch.vtensor<[],f32>
// CHECK: }
// An empty dim list with keepdim=false sums over every dimension and
// produces a rank-0 (scalar) result tensor.
func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
  %dtype_none = torch.constant.none
  %keep_dims_false = torch.constant.bool false
  %all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
  %sum_all_dims = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
  return %sum_all_dims : !torch.vtensor<[],f32>
}
348+
349+ // -----
350+
// CHECK-LABEL: func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(
// CHECK-SAME: %[[INPUT_I8:.*]]: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
// CHECK: %[[INPUT_I8_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_I8]] : !torch.vtensor<[2,3,4],si8> -> tensor<2x3x4xi8>
// CHECK: %[[DTYPE_I32:.*]] = torch.constant.int 3
// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[CAST_INPUT_TO_I32:.*]] = tosa.cast %[[INPUT_I8_TENSOR]] : (tensor<2x3x4xi8>) -> tensor<2x3x4xi32>
// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[CAST_INPUT_TO_I32]] {axis = 0 : i32} : (tensor<2x3x4xi32>) -> tensor<1x3x4xi32>
// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xi32>) -> tensor<1x1x4xi32>
// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xi32>) -> tensor<1x1x1xi32>
// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xi32>, !tosa.shape<0>) -> tensor<i32>
// CHECK: %[[RESULT_I32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<i32> -> !torch.vtensor<[],si32>
// CHECK: return %[[RESULT_I32]] : !torch.vtensor<[],si32>
// CHECK: }
// Same empty-dim-list reduction, but with an explicit output dtype
// (torch.constant.int 3 == si32), so the i8 input is cast to i32 before
// the reduction chain.
func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(%arg0: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
  %dtype_i32 = torch.constant.int 3
  %keep_dims_false = torch.constant.bool false
  %all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
  %sum_all_dims_to_i32 = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_i32 : !torch.vtensor<[2,3,4],si8>, !torch.list<int>, !torch.bool, !torch.int -> !torch.vtensor<[],si32>
  return %sum_all_dims_to_i32 : !torch.vtensor<[],si32>
}
372+
373+ // -----
374+
328375// CHECK-LABEL: func.func @test_linalg_vector_norm$basic(
329376// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
330377// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>
0 commit comments