@@ -4186,6 +4186,7 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
return %4 : !torch.vtensor<[2,8,4,4],si32>
}

+ // -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
@@ -4226,7 +4227,6 @@ func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,
}

// -----
-
// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
@@ -4264,3 +4264,115 @@ func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}
+
+ // -----
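+ // aten.mm on f32 operands: both operands are reshaped to rank 3, fed to tosa.matmul with f32 zero-point constants, and the rank-3 result is reshaped back to [1,10].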
+ // CHECK-LABEL: func.func @torch.aten.mm$f32(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[1,22],f32>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[22,10],f32> -> tensor<22x10xf32>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[1,22],f32> -> tensor<1x22xf32>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 22]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<1x22xf32>, !tosa.shape<3>) -> tensor<1x1x22xf32>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 22, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<22x10xf32>, !tosa.shape<3>) -> tensor<1x22x10xf32>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x1x22xf32>, tensor<1x22x10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x10xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x1x10xf32>, !tosa.shape<2>) -> tensor<1x10xf32>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<1x10xf32> -> !torch.vtensor<[1,10],f32>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.mm$f32(%arg0: !torch.vtensor<[1,22],f32>, %arg1: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> {
+ %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f32>, !torch.vtensor<[22,10],f32> -> !torch.vtensor<[1,10],f32>
+ return %0 : !torch.vtensor<[1,10],f32>
+ }
+
+ // -----
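+ // aten.mm on si8 operands: tosa.matmul accumulates into i32 and the result is cast back to i8.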
+ // CHECK-LABEL: func.func @torch.aten.mm$si8
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xi8>, tensor<1x22x10xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x10xi32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xi32>) -> tensor<1x10xi8>
+ func.func @torch.aten.mm$si8(%arg0: !torch.vtensor<[1,22],si8>, %arg1: !torch.vtensor<[22,10],si8>) -> !torch.vtensor<[1,10],si8> {
+ %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],si8>, !torch.vtensor<[22,10],si8> -> !torch.vtensor<[1,10],si8>
+ return %0 : !torch.vtensor<[1,10],si8>
+ }
+
+ // -----
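+ // aten.mm on f16 operands: tosa.matmul produces an f32 accumulator that is cast back to f16.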
+ // CHECK-LABEL: func.func @torch.aten.mm$f16
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xf16>, tensor<1x22x10xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x1x10xf32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xf16>
+ func.func @torch.aten.mm$f16(%arg0: !torch.vtensor<[1,22],f16>, %arg1: !torch.vtensor<[22,10],f16>) -> !torch.vtensor<[1,10],f16> {
+ %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f16>, !torch.vtensor<[22,10],f16> -> !torch.vtensor<[1,10],f16>
+ return %4 : !torch.vtensor<[1,10],f16>
+ }
+
+ // -----
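+ // aten.mm on bf16 operands: tosa.matmul produces an f32 accumulator that is cast back to bf16.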
+ // CHECK-LABEL: func.func @torch.aten.mm$bf16
+ // CHECK: tosa.matmul
+ // CHECK-SAME: (tensor<1x1x22xbf16>, tensor<1x22x10xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1x1x10xf32>
+ // CHECK-NOT: torch.aten.mm
+ // CHECK: tosa.cast
+ // CHECK-SAME: (tensor<1x10xf32>) -> tensor<1x10xbf16>
+ func.func @torch.aten.mm$bf16(%arg0: !torch.vtensor<[1,22],bf16>, %arg1: !torch.vtensor<[22,10],bf16>) -> !torch.vtensor<[1,10],bf16> {
+ %4 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],bf16>, !torch.vtensor<[22,10],bf16> -> !torch.vtensor<[1,10],bf16>
+ return %4 : !torch.vtensor<[1,10],bf16>
+ }
+
+ // -----
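+ // aten.matmul of a [10,3,4] tensor with a rank-1 [4] vector: the vector is reshaped and transposed into a rank-3 operand, the batch dims are collapsed into a single [1,30,4] matmul, and the result is reshaped to [10,3].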
+ // CHECK-LABEL: func.func @torch.aten.matmul$broadcast(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[10,3,4],f32>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<4xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 30, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<10x3x4xf32>, !tosa.shape<3>) -> tensor<1x30x4xf32>
+ // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_RESHAPE]] {perms = array<i32: 1, 0, 2>} : (tensor<1x4x1xf32>) -> tensor<4x1x1xf32>
+ // CHECK: %[[WTS_SHAPE_2:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE_2:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE_2]] : (tensor<4x1x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE_2]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x30x4xf32>, tensor<1x4x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x30x1xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[10, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x30x1xf32>, !tosa.shape<2>) -> tensor<10x3xf32>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<10x3xf32> -> !torch.vtensor<[10,3],f32>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> {
+ %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32>
+ return %0 : !torch.vtensor<[10,3],f32>
+ }
+
+ // -----
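+ // aten.linear on f16: the weight is transposed, matmul accumulates in f32, the result is cast back to f16, and the bias is broadcast-added.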
4352
+ // CHECK-LABEL: func.func @torch.aten.linear$f16(
4353
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[2,4],f16>,
4354
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[3,4],f16>,
4355
+ // CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
4356
+ // CHECK: %[[BIAS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[BIAS]] : !torch.vtensor<[3],f16> -> tensor<3xf16>
4357
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16>
4358
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16>
4359
+ // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array<i32: 1, 0>} : (tensor<3x4xf16>) -> tensor<4x3xf16>
4360
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
4361
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16>
4362
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
4363
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16>
4364
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
4365
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
4366
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32>
4367
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
4368
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
4369
+ // CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16>
4370
+ // CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
4371
+ // CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
4372
+ // CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>
4373
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16>
4374
+ // CHECK: return %[[RES]]
4375
+ func.func @torch.aten.linear$f16 (%arg0: !torch.vtensor <[2 ,4 ],f16 >, %arg1: !torch.vtensor <[3 ,4 ],f16 >, %arg2: !torch.vtensor <[3 ],f16 >) -> !torch.vtensor <[2 ,3 ],f16 > {
4376
+ %0 = torch.aten.linear %arg0 , %arg1 , %arg2 : !torch.vtensor <[2 ,4 ],f16 >, !torch.vtensor <[3 ,4 ],f16 >, !torch.vtensor <[3 ],f16 > -> !torch.vtensor <[2 ,3 ],f16 >
4377
+ return %0 : !torch.vtensor <[2 ,3 ],f16 >
4378
+ }