// Patch summary (free text ahead of the first `diff --git` header).
//
// Purpose: in the ConvertAtenMatmulOp hunks of
// lib/Conversion/TorchToLinalg/Linear.cpp, stop building the broadcast target
// shapes as all-dynamic. Previously `lhsTargetShape` / `rhsTargetShape` were
// filled with `ShapedType::kDynamic` for every entry of
// `lhsBroadcastToShape` / `rhsBroadcastToShape`. After this patch each entry is
// `getConstantIntValue(v).value_or(ShapedType::kDynamic)`, i.e. the constant
// extent when one can be extracted from the Value and kDynamic otherwise. The
// resulting shapes are passed to `RankedTensorType::get` to form
// `lhsBroadcastType` / `rhsBroadcastType`. The old
// "TODO: Improve usage of static shape information." comment is removed.
//
// Test: test/Conversion/TorchToLinalg/basic.mlir gains
// `@torch.aten.matmul.4d`, which multiplies a [1,2,400,32] operand by a
// [1,2,32,400] operand (note the op takes %arg1 first, then %arg0) and checks
// that the lowered IR carries fully static tensor types through
// tensor.empty / tensor.cast / tensor.collapse_shape / linalg.batch_matmul /
// tensor.expand_shape.
//
// NOTE(review): `SmallVector` and `OpConversionPattern` appear here without
// template arguments; presumably `<...>` spans were stripped when this text
// was extracted -- confirm against the upstream commit before applying.
// NOTE(review): the patch text below is collapsed onto a few physical lines;
// the original per-line layout of the hunks must be restored before it can be
// applied with `git apply` / `patch`.
// NOTE(review): the CHECK lines pin each `arith.constant` in emission order,
// so changes to how the lowering emits constants will likely require
// regenerating this test -- verify this level of strictness is intended.
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9ffb7c1dc0f3..89ec9a599e4b 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -505,9 +505,11 @@ class ConvertAtenMatmulOp : public OpConversionPattern { // Broadcast the batch dimensions of both the matrices. Value broadcastedLhs, broadcastedRhs; - // TODO: Improve usage of static shape information. - SmallVector lhsTargetShape(lhsBroadcastToShape.size(), - ShapedType::kDynamic); + SmallVector lhsTargetShape = + llvm::to_vector(llvm::map_range(lhsBroadcastToShape, [](Value v) { + return getConstantIntValue(v).value_or(ShapedType::kDynamic); + })); + auto lhsBroadcastType = RankedTensorType::get( lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( @@ -516,8 +518,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } - SmallVector rhsTargetShape(rhsBroadcastToShape.size(), - ShapedType::kDynamic); + SmallVector rhsTargetShape = + llvm::to_vector(llvm::map_range(rhsBroadcastToShape, [](Value v) { + return getConstantIntValue(v).value_or(ShapedType::kDynamic); + })); auto rhsBroadcastType = RankedTensorType::get( rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 262f6e646bdd..5895d19a2ec1 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -43,6 +43,70 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch // ----- +// CHECK-LABEL: func.func @torch.aten.matmul.4d +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,2,32,400],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1,2,400,32],f32>) -> 
!torch.vtensor<[1,2,400,400],f32> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,2,32,400],f32> -> tensor<1x2x32x400xf32> +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[1,2,400,32],f32> -> tensor<1x2x400x32xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_11:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_12:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_13:.*]] = arith.constant 400 : index +// CHECK: %[[VAL_14:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_15:.*]] = arith.constant 32 : index +// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_17:.*]] = arith.constant 32 : index +// CHECK: %[[VAL_18:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_19:.*]] = arith.constant 400 : index +// CHECK: %[[VAL_20:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_21:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_21]] : i64 +// CHECK: cf.assert %[[VAL_22]], "mismatching contracting dimension" +// CHECK: %[[VAL_23:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_24:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_25:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_26:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_27:.*]] = arith.constant 400 : i64 +// CHECK: %[[VAL_28:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_29:.*]] = arith.constant 32 : i64 +// CHECK: %[[VAL_30:.*]] = arith.constant 400 : i64 +// CHECK: %[[VAL_31:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_32:.*]] = arith.constant 0 : 
index +// CHECK: %[[VAL_33:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1x2x400x32xf32> +// CHECK: %[[VAL_35:.*]] = tensor.cast %[[VAL_1]] : tensor<1x2x400x32xf32> to tensor<1x2x400x32xf32> +// CHECK: %[[VAL_36:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_37:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_38:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_39:.*]] = tensor.empty() : tensor<1x2x32x400xf32> +// CHECK: %[[VAL_40:.*]] = tensor.cast %[[VAL_0]] : tensor<1x2x32x400xf32> to tensor<1x2x32x400xf32> +// CHECK: %[[VAL_41:.*]] = tensor.collapse_shape %[[VAL_35]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x400x32xf32> into tensor<2x400x32xf32> +// CHECK: %[[VAL_42:.*]] = tensor.collapse_shape %[[VAL_40]] {{\[\[}}0, 1], [2], [3]] : tensor<1x2x32x400xf32> into tensor<2x32x400xf32> +// CHECK: %[[VAL_43:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_44:.*]] = tensor.empty() : tensor<2x400x400xf32> +// CHECK: %[[VAL_45:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : f32) outs(%[[VAL_44]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32> +// CHECK: %[[VAL_47:.*]] = linalg.batch_matmul ins(%[[VAL_41]], %[[VAL_42]] : tensor<2x400x32xf32>, tensor<2x32x400xf32>) outs(%[[VAL_46]] : tensor<2x400x400xf32>) -> tensor<2x400x400xf32> +// CHECK: %[[VAL_48:.*]] = tensor.expand_shape %[[VAL_47]] {{\[\[}}0, 1], [2], [3]] output_shape [1, 2, 400, 400] : tensor<2x400x400xf32> into tensor<1x2x400x400xf32> +// CHECK: %[[VAL_49:.*]] = tensor.cast %[[VAL_48]] : tensor<1x2x400x400xf32> to tensor<1x2x400x400xf32> +// CHECK: %[[VAL_50:.*]] = torch_c.from_builtin_tensor %[[VAL_49]] : tensor<1x2x400x400xf32> -> !torch.vtensor<[1,2,400,400],f32> +// CHECK: return %[[VAL_50]] : !torch.vtensor<[1,2,400,400],f32> +// CHECK: } +func.func @torch.aten.matmul.4d(%arg0: !torch.vtensor<[1,2,32,400],f32>, %arg1: !torch.vtensor<[1,2,400,32],f32>) -> !torch.vtensor<[1,2,400,400],f32> { + %0 = torch.aten.matmul 
%arg1, %arg0 : !torch.vtensor<[1,2,400,32],f32>, !torch.vtensor<[1,2,32,400],f32> -> !torch.vtensor<[1,2,400,400],f32> + return %0 : !torch.vtensor<[1,2,400,400],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-NOT: assert func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>