Skip to content

Commit 232e854

Browse files
authored
[mlir-gen] Add mixed precision type support (1/N). (#1075)
Adds creation of gemm kernels with mixed types such as bf16 -> f32, f16 -> f32, i8 -> i32.
1 parent fc178c2 commit 232e854

File tree

4 files changed

+168
-23
lines changed

4 files changed

+168
-23
lines changed

test/Integration/mlir-gen-matmul.mlir

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
// RUN: mlir-gen --kernel=args --seed=0 --float-type=bf16 --batch=128 --layers=2304,768 --tiles=64,48,64 2>&1 | FileCheck %s --check-prefix=BF16
33
// RUN: mlir-gen --kernel=args --seed=0 --float-type=f16 --batch=128 --layers=2304,768 --tiles=64,48,64 2>&1 | FileCheck %s --check-prefix=FP16
44

5+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-bf16 --batch=128 --layers=2304,768 --tiles=64,48,64 2>&1 | FileCheck %s --check-prefix=MXBF16-GENERIC
6+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-i8 --batch=128 --layers=2304,768 --tiles=64,48,64 2>&1 | FileCheck %s --check-prefix=MXI8-GENERIC
7+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-f16 --batch=128 --layers=2304,768 --tiles=64,48,64 2>&1 | FileCheck %s --check-prefix=MXF16-GENERIC
8+
9+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-bf16 --batch=128 --layers=2304,768 --tiles=64,48,64 --output=contract 2>&1 | FileCheck %s --check-prefix=MXBF16-CONTRACT
10+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-i8 --batch=128 --layers=2304,768 --tiles=64,48,64 --output=contract 2>&1 | FileCheck %s --check-prefix=MXI8-CONTRACT
11+
// RUN: mlir-gen --kernel=args --seed=0 --float-type=mx-f16 --batch=128 --layers=2304,768 --tiles=64,48,64 --output=contract 2>&1 | FileCheck %s --check-prefix=MXF16-CONTRACT
12+
513
// FP32: // RUN{{.*}}tpp-run %s -n {{\d*}}
614
// FP32: // RUN{{.*}}-e entry -entry-point-result=void
715
// FP32: // BENCH_TOTAL_FLOPS: 452984832
@@ -40,3 +48,90 @@
4048
// FP16: arith.mulf
4149
// FP16: arith.addf
4250
// FP16-NOT: dealloc
51+
52+
// MXBF16-GENERIC: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
53+
// MXBF16-GENERIC: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
54+
// MXBF16-GENERIC: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
55+
// MXBF16-GENERIC-LABEL: func.func @entry(
56+
// MXBF16-GENERIC-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xbf16>,
57+
// MXBF16-GENERIC-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xbf16>,
58+
// MXBF16-GENERIC-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32> {
59+
// MXBF16-GENERIC: %[[VAL_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xbf16>, tensor<16x36x64x48xbf16>) outs(%[[ARG2]] : tensor<2x16x64x48xf32>) {
60+
// MXBF16-GENERIC: ^bb0(%[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: bf16, %[[VAL_3:.*]]: f32):
61+
// MXBF16-GENERIC: %[[VAL_4:.*]] = arith.extf %[[VAL_1]] : bf16 to f32
62+
// MXBF16-GENERIC: %[[VAL_5:.*]] = arith.extf %[[VAL_2]] : bf16 to f32
63+
// MXBF16-GENERIC: %[[VAL_6:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
64+
// MXBF16-GENERIC: %[[VAL_7:.*]] = arith.addf %[[VAL_3]], %[[VAL_6]] : f32
65+
// MXBF16-GENERIC: linalg.yield %[[VAL_7]] : f32
66+
// MXBF16-GENERIC: } -> tensor<2x16x64x48xf32>
67+
// MXBF16-GENERIC: return %[[VAL_0]] : tensor<2x16x64x48xf32>
68+
// MXBF16-GENERIC: }
69+
70+
// MXI8-GENERIC: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
71+
// MXI8-GENERIC: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
72+
// MXI8-GENERIC: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
73+
// MXI8-GENERIC-LABEL: func.func @entry(
74+
// MXI8-GENERIC-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xi8>,
75+
// MXI8-GENERIC-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xi8>,
76+
// MXI8-GENERIC-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xi32>) -> tensor<2x16x64x48xi32> {
77+
// MXI8-GENERIC: %[[VAL_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xi8>, tensor<16x36x64x48xi8>) outs(%[[ARG2]] : tensor<2x16x64x48xi32>) {
78+
// MXI8-GENERIC: ^bb0(%[[VAL_1:.*]]: i8, %[[VAL_2:.*]]: i8, %[[VAL_3:.*]]: i32):
79+
// MXI8-GENERIC: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : i8 to i32
80+
// MXI8-GENERIC: %[[VAL_5:.*]] = arith.extsi %[[VAL_2]] : i8 to i32
81+
// MXI8-GENERIC: %[[VAL_6:.*]] = arith.muli %[[VAL_4]], %[[VAL_5]] : i32
82+
// MXI8-GENERIC: %[[VAL_7:.*]] = arith.addi %[[VAL_3]], %[[VAL_6]] : i32
83+
// MXI8-GENERIC: linalg.yield %[[VAL_7]] : i32
84+
// MXI8-GENERIC: } -> tensor<2x16x64x48xi32>
85+
// MXI8-GENERIC: return %[[VAL_0]] : tensor<2x16x64x48xi32>
86+
// MXI8-GENERIC: }
87+
88+
// MXBF16-CONTRACT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
89+
// MXBF16-CONTRACT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
90+
// MXBF16-CONTRACT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
91+
// MXBF16-CONTRACT-LABEL: func.func @entry(
92+
// MXBF16-CONTRACT-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xbf16>,
93+
// MXBF16-CONTRACT-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xbf16>,
94+
// MXBF16-CONTRACT-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32> {
95+
// MXBF16-CONTRACT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xbf16>, tensor<16x36x64x48xbf16>) outs(%[[ARG2]] : tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32>
96+
// MXBF16-CONTRACT: return %[[VAL_0]] : tensor<2x16x64x48xf32>
97+
// MXBF16-CONTRACT: }
98+
99+
// MXI8-CONTRACT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
100+
// MXI8-CONTRACT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
101+
// MXI8-CONTRACT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
102+
// MXI8-CONTRACT-LABEL: func.func @entry(
103+
// MXI8-CONTRACT-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xi8>,
104+
// MXI8-CONTRACT-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xi8>,
105+
// MXI8-CONTRACT-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xi32>) -> tensor<2x16x64x48xi32> {
106+
// MXI8-CONTRACT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xi8>, tensor<16x36x64x48xi8>) outs(%[[ARG2]] : tensor<2x16x64x48xi32>) -> tensor<2x16x64x48xi32>
107+
// MXI8-CONTRACT: return %[[VAL_0]] : tensor<2x16x64x48xi32>
108+
// MXI8-CONTRACT: }
109+
110+
// MXF16-GENERIC: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
111+
// MXF16-GENERIC: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
112+
// MXF16-GENERIC: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
113+
// MXF16-GENERIC-LABEL: func.func @entry(
114+
// MXF16-GENERIC-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xf16>,
115+
// MXF16-GENERIC-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xf16>,
116+
// MXF16-GENERIC-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32> {
117+
// MXF16-GENERIC: %[[VAL_0:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xf16>, tensor<16x36x64x48xf16>) outs(%[[ARG2]] : tensor<2x16x64x48xf32>) {
118+
// MXF16-GENERIC: ^bb0(%[[VAL_1:.*]]: f16, %[[VAL_2:.*]]: f16, %[[VAL_3:.*]]: f32):
119+
// MXF16-GENERIC: %[[VAL_4:.*]] = arith.extf %[[VAL_1]] : f16 to f32
120+
// MXF16-GENERIC: %[[VAL_5:.*]] = arith.extf %[[VAL_2]] : f16 to f32
121+
// MXF16-GENERIC: %[[VAL_6:.*]] = arith.mulf %[[VAL_4]], %[[VAL_5]] : f32
122+
// MXF16-GENERIC: %[[VAL_7:.*]] = arith.addf %[[VAL_3]], %[[VAL_6]] : f32
123+
// MXF16-GENERIC: linalg.yield %[[VAL_7]] : f32
124+
// MXF16-GENERIC: } -> tensor<2x16x64x48xf32>
125+
// MXF16-GENERIC: return %[[VAL_0]] : tensor<2x16x64x48xf32>
126+
// MXF16-GENERIC: }
127+
128+
// MXF16-CONTRACT: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
129+
// MXF16-CONTRACT: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
130+
// MXF16-CONTRACT: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
131+
// MXF16-CONTRACT-LABEL: func.func @entry(
132+
// MXF16-CONTRACT-SAME: %[[ARG0:.*]]: tensor<2x36x64x64xf16>,
133+
// MXF16-CONTRACT-SAME: %[[ARG1:.*]]: tensor<16x36x64x48xf16>,
134+
// MXF16-CONTRACT-SAME: %[[ARG2:.*]]: tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32> {
135+
// MXF16-CONTRACT: %[[VAL_0:.*]] = linalg.contract indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[ARG0]], %[[ARG1]] : tensor<2x36x64x64xf16>, tensor<16x36x64x48xf16>) outs(%[[ARG2]] : tensor<2x16x64x48xf32>) -> tensor<2x16x64x48xf32>
136+
// MXF16-CONTRACT: return %[[VAL_0]] : tensor<2x16x64x48xf32>
137+
// MXF16-CONTRACT: }

tools/mlir-gen/MLIRGen.cpp

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,27 @@ MLIRGenerator::MLIRGenerator(StringRef outputOpKindStr, StringRef kernelStr,
114114
"Must have 3 tile sizes (or none)");
115115

116116
// Pick data type
117-
auto elementType = llvm::StringSwitch<std::optional<Type>>(targetType)
118-
.CaseLower("f32", builder.getF32Type())
119-
.CaseLower("f16", builder.getF16Type())
120-
.CaseLower("bf16", builder.getBF16Type())
121-
.Default(std::nullopt);
117+
auto elementType =
118+
llvm::StringSwitch<std::optional<SmallVector<mlir::Type>>>(targetType)
119+
.CaseLower("f32", SmallVector<Type>{builder.getF32Type(),
120+
builder.getF32Type()})
121+
.CaseLower("f16", SmallVector<Type>{builder.getF16Type(),
122+
builder.getF16Type()})
123+
.CaseLower("bf16", SmallVector<Type>{builder.getBF16Type(),
124+
builder.getBF16Type()})
125+
.CaseLower("mx-bf16", SmallVector<Type>{builder.getBF16Type(),
126+
builder.getF32Type()})
127+
.CaseLower("mx-f16", SmallVector<Type>{builder.getF16Type(),
128+
builder.getF32Type()})
129+
.CaseLower("mx-i8", SmallVector<Type>{builder.getIntegerType(8),
130+
builder.getI32Type()})
131+
.Default(std::nullopt);
122132
assert(elementType && "Unsupported data type");
123-
dataType = *elementType;
133+
dataTypes.push_back((*elementType)[0]);
134+
dataTypes.push_back((*elementType)[1]);
124135

125136
// Disable VNNI packing if it is not a F16/BF16 data type
126-
if (!dataType.isBF16() && !dataType.isF16())
137+
if (!dataTypes[0].isBF16() && !dataTypes[0].isF16())
127138
vnniFactor = 0;
128139
assert(((vnniFactor >= 0) && (vnniFactor % 2 == 0)) &&
129140
"Invalid VNNI packing factor");
@@ -437,9 +448,45 @@ Value MLIRGenerator::lowerGenericMatmul(Value input, Value weight,
437448
auto arg0 = blockArgs[0];
438449
auto arg1 = blockArgs[1];
439450
auto arg2 = blockArgs[2];
440-
auto mul = nestedBuilder.create<arith::MulFOp>(loc, arg0, arg1);
441-
auto add = nestedBuilder.create<arith::AddFOp>(loc, arg2, mul);
442-
nestedBuilder.create<linalg::YieldOp>(loc, ValueRange{add});
451+
// If input and output type differs, up cast input to output
452+
// type using arith.extf/arith.extsi.
453+
Type inputElementType =
454+
cast<ShapedType>(input.getType()).getElementType();
455+
Type weightElementType =
456+
cast<ShapedType>(weight.getType()).getElementType();
457+
Type outputElementType =
458+
cast<ShapedType>(output.getType()).getElementType();
459+
if (inputElementType != outputElementType) {
460+
if (inputElementType.isFloat()) {
461+
arg0 = nestedBuilder.create<arith::ExtFOp>(
462+
loc, outputElementType, arg0);
463+
} else {
464+
arg0 = nestedBuilder.create<arith::ExtSIOp>(
465+
loc, outputElementType, arg0);
466+
}
467+
}
468+
469+
if (weightElementType != outputElementType) {
470+
if (weightElementType.isFloat()) {
471+
arg1 = nestedBuilder.create<arith::ExtFOp>(
472+
loc, outputElementType, arg1);
473+
} else {
474+
arg1 = nestedBuilder.create<arith::ExtSIOp>(
475+
loc, outputElementType, arg1);
476+
}
477+
}
478+
479+
auto *mul =
480+
outputElementType.isFloat()
481+
? nestedBuilder.create<arith::MulFOp>(loc, arg0, arg1)
482+
: nestedBuilder.create<arith::MulIOp>(loc, arg0, arg1);
483+
auto *add = outputElementType.isFloat()
484+
? nestedBuilder.create<arith::AddFOp>(
485+
loc, arg2, mul->getResult(0))
486+
: nestedBuilder.create<arith::AddIOp>(
487+
loc, arg2, mul->getResult(0));
488+
nestedBuilder.create<linalg::YieldOp>(
489+
loc, ValueRange{add->getResults()});
443490
})
444491
.getResult(0);
445492

@@ -520,7 +567,7 @@ Value MLIRGenerator::lowerNamedRelu(Value input, Value output) {
520567
return input;
521568

522569
auto outTy = cast<ShapedType>(input.getType());
523-
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataType));
570+
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataTypes[0]));
524571
Value emptyTensor = builder.create<tensor::EmptyOp>(loc, outTy, ValueRange{});
525572
auto fill =
526573
builder.create<linalg::FillOp>(loc, zero, emptyTensor)->getResult(0);
@@ -538,7 +585,7 @@ Value MLIRGenerator::lowerRelu(Value input, Value output) {
538585
if (!enableRelu)
539586
return input;
540587

541-
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataType));
588+
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataTypes[0]));
542589
auto outTy = cast<ShapedType>(input.getType());
543590
auto map = getMap(input, MAP_PARALLEL);
544591
auto relu =
@@ -602,7 +649,7 @@ Value MLIRGenerator::lowerSoftmax(Value input, Value output) {
602649
auto redTy = getShape(dims, PACK_OUTPUT);
603650
Value redTensor =
604651
builder.create<tensor::EmptyOp>(loc, dims, outTy.getElementType());
605-
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataType));
652+
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataTypes[0]));
606653
auto fill = builder.create<linalg::FillOp>(loc, zero, redTensor);
607654
auto redux = builder.create<linalg::GenericOp>(
608655
loc, redTy, ValueRange{exp.getResult(0)}, ValueRange{fill.getResult(0)},
@@ -651,11 +698,13 @@ Value MLIRGenerator::lowerSoftmax(Value input, Value output) {
651698
TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
652699
// Already packed type, just return ND tensor
653700
if (dims.size() > 2)
654-
return RankedTensorType::get(dims, dataType);
701+
return RankedTensorType::get(dims, type == PACK_OUTPUT ? dataTypes[1]
702+
: dataTypes[0]);
655703

656704
// Unpacked type, just return 2D tensor
657705
if (!tiles.size())
658-
return RankedTensorType::get(dims, dataType);
706+
return RankedTensorType::get(dims, type == PACK_OUTPUT ? dataTypes[1]
707+
: dataTypes[0]);
659708

660709
// Packed types block by tile size
661710
assert(tiles.size() == 3 && "Invalid tile size format");
@@ -671,7 +720,7 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
671720
assert(x % n == 0 && "Invalid tile size for N dim");
672721
assert(y % c == 0 && "Invalid tile size for C dim");
673722
// N x C -> BN x BC x bn x bc
674-
return RankedTensorType::get({x / n, y / c, n, c}, dataType);
723+
return RankedTensorType::get({x / n, y / c, n, c}, dataTypes[0]);
675724
case PACK_WEIGHT:
676725
// VNNI packing can be done via tpp-opt --vnni-pack
677726
assert(x % k == 0 && "Invalid tile size for K dim");
@@ -680,20 +729,20 @@ TensorType MLIRGenerator::getShape(ArrayRef<int64_t> dims, PackingType type) {
680729
// VNNI: C x K -> BK x BC x bc/vnni x bk x vnni
681730
if (vnniFactor != 0)
682731
return RankedTensorType::get(
683-
{y / k, x / c, c / vnniFactor, k, vnniFactor}, dataType);
732+
{y / k, x / c, c / vnniFactor, k, vnniFactor}, dataTypes[0]);
684733

685734
// C x K -> BK x BC x bc x bk
686-
return RankedTensorType::get({y / k, x / c, c, k}, dataType);
735+
return RankedTensorType::get({y / k, x / c, c, k}, dataTypes[0]);
687736
case PACK_OUTPUT:
688737
assert(x % n == 0 && "Invalid tile size for N dim");
689738

690739
// Broadcast 1D -> 2D is Bk x bk only
691740
if (!y)
692-
return RankedTensorType::get({x / k, k}, dataType);
741+
return RankedTensorType::get({x / k, k}, dataTypes[1]);
693742

694743
// N x K -> BN x BK x bn x bk
695744
assert(y % k == 0 && "Invalid tile size for K dim");
696-
return RankedTensorType::get({x / n, y / k, n, k}, dataType);
745+
return RankedTensorType::get({x / n, y / k, n, k}, dataTypes[1]);
697746
}
698747

699748
llvm_unreachable("Unknown packing type");
@@ -838,7 +887,7 @@ int MLIRGenerator::getRand() {
838887
}
839888

840889
Value MLIRGenerator::getZeroInitTensor(TensorType type) {
841-
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataType));
890+
auto zero = getConstFloat(builder, 0.0, cast<FloatType>(dataTypes[0]));
842891
Value tensor =
843892
builder.create<tensor::EmptyOp>(loc, type, ValueRange{}).getResult();
844893
tensor = builder.create<linalg::FillOp>(loc, zero, tensor).getResult(0);

tools/mlir-gen/MLIRGen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class MLIRGenerator {
5252
SmallVector<int64_t> tiles;
5353

5454
/// Data type (element type of all tensors)
55-
Type dataType;
55+
SmallVector<Type> dataTypes;
5656

5757
/// Random seed
5858
int seed;

tools/mlir-gen/mlir-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ llvm::cl::opt<std::string>
6868
// Float type
6969
llvm::cl::opt<std::string>
7070
floatType("float-type", llvm::cl::desc("Float type and its bitsize"),
71-
llvm::cl::value_desc("f32|f16|bf16"), llvm::cl::init("f32"));
71+
llvm::cl::value_desc("f32|f16|bf16|mx-bf16|mx-f16|mx-i8"),
72+
llvm::cl::init("f32"));
7273

7374
// Random seed
7475
llvm::cl::opt<int> seed("seed", llvm::cl::desc("Random seed"),

0 commit comments

Comments
 (0)