Skip to content

Commit 8282aab

Browse files
lhutton1tatwaichong
authored andcommitted
[mlir][tosa] Add support for cast_from/to_block_scaled (llvm#163436)
This commit adds support for the cast_from/to_block_scaled operations from the ext-mxfp extension. This includes: - Operation definition in TosaOps.td - Micro-scaling supported types definition - Shape inference and verifiers - Validation pass checks to ensure usage is only valid when the target environment includes ext-mxfp and at least v1.1.draft of the specification. Note: currently it excludes support for mxint8. This will be added in a later commit. Note: this commit adds support as defined in the spec in arm/tosa-specification@063846a. EXT_MXFP extension is considered experimental and subject to breaking change. Co-authored-by: Tat Wai Chong <[email protected]>
1 parent fadab0d commit 8282aab

15 files changed

+529
-35
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,34 @@ extensionComplianceMap = {
864864
{{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
865865
{{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}},
866866
allOf}}},
867+
{"tosa.cast_from_block_scaled",
868+
{{{Extension::bf16, Extension::mxfp},
869+
{{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
870+
{{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
871+
{{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
872+
{{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
873+
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
874+
allOf},
875+
{{Extension::mxfp},
876+
{{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
877+
{{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
878+
{{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
879+
{{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
880+
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
881+
{"tosa.cast_to_block_scaled",
882+
{{{Extension::mxfp},
883+
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
884+
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
885+
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
886+
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
887+
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
888+
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
889+
{{Extension::bf16, Extension::mxfp},
890+
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
891+
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
892+
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
893+
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}},
894+
allOf}}},
867895
{"tosa.rescale",
868896
{{{Extension::int16},
869897
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,69 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
24702470
let hasFolder = 1;
24712471
}
24722472

2473+
//===----------------------------------------------------------------------===//
2474+
// Operator: cast_from_block_scaled
2475+
//===----------------------------------------------------------------------===//
2476+
def Tosa_CastFromBlockScaledOp: Tosa_InferShapedTypeOp<"cast_from_block_scaled"> {
2477+
let summary = "Apply scales from a scale tensor to the values in a value tensor";
2478+
2479+
let description = [{
2480+
Apply the scales from a scale tensor to the values in a value tensor, casting
2481+
the result to the output type. The block dimension must be the last dimension
2482+
of the tensor.
2483+
}];
2484+
2485+
let arguments = (ins
2486+
Tosa_MXFPDataTensorAtLeast1D:$input_data,
2487+
Tosa_MXFPScaleTensorAtLeast1D:$input_scale,
2488+
Tosa_BlockSizeAttr:$block_size
2489+
);
2490+
2491+
let results = (outs
2492+
Tosa_TensorAtLeast1D: $output_data
2493+
);
2494+
2495+
list<Availability> availability = [
2496+
Profile<[Tosa_PRO_FP]>,
2497+
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>,
2498+
];
2499+
2500+
let hasVerifier = 1;
2501+
let hasCustomAssemblyFormat = 1;
2502+
}
2503+
2504+
//===----------------------------------------------------------------------===//
2505+
// Operator: cast_to_block_scaled
2506+
//===----------------------------------------------------------------------===//
2507+
def Tosa_CastToBlockScaledOp : Tosa_InferShapedTypeOp<"cast_to_block_scaled"> {
2508+
let summary = "Calculate scale tensor values per block, output to separate scale and data tensors.";
2509+
2510+
let description = [{
2511+
Calculate a scale value per block of input values and use that to calculate
2512+
scaled data values from an input tensor. The output tensors are cast to the
2513+
specified scale and value types. The block dimension will be the last dimension
2514+
of the tensor.
2515+
}];
2516+
2517+
let arguments = (ins
2518+
Tosa_TensorAtLeast1D:$input_data,
2519+
Tosa_BlockSizeAttr:$block_size
2520+
);
2521+
2522+
let results = (outs
2523+
Tosa_MXFPDataTensorAtLeast1D:$output_data,
2524+
Tosa_MXFPScaleTensorAtLeast1D:$output_scale
2525+
);
2526+
2527+
list<Availability> availability = [
2528+
Profile<[Tosa_PRO_FP]>,
2529+
Extension<[Tosa_EXT_BF16, Tosa_EXT_MXFP]>
2530+
];
2531+
2532+
let hasVerifier = 1;
2533+
let hasCustomAssemblyFormat = 1;
2534+
}
2535+
24732536
//===----------------------------------------------------------------------===//
24742537
// Operator: rescale
24752538
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ class ProfileInfoDepot {
7979

8080
LogicalResult populatationDispatch(Operation *op);
8181

82-
LogicalResult populateProfileInfo(ValueRange operands, Value output);
82+
// Add input operands and output results to the profile type info list
83+
LogicalResult populateProfileInfo(ValueRange operands, ValueRange results);
8384

8485
// Base
8586
template <typename T>

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,16 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
199199
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
200200
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
201201
]>;
202+
def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
203+
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
204+
TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
205+
"tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
206+
>;
207+
def Tosa_MXFPScaleTensorAtLeast1D : AnyTypeOf<[
208+
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
209+
TosaRankedTensorOf<[Tosa_MXFPScaleNumber], [AtLeastRankOne]>],
210+
"tosa-conformant tensor of at least rank 1", "::mlir::TensorType"
211+
>;
202212

203213
//===----------------------------------------------------------------------===//
204214
// Generic scalar, vector, or tensor of a particular type.

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
370370
result.operands)))
371371
return failure();
372372

373-
result.addTypes(fnTy.getResult(0));
373+
result.addTypes(fnTy.getResults());
374374
result.addAttributes(attrs);
375375

376376
return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532532
printWithEnumHandling(parser, *this);
533533
}
534534

535+
ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
536+
OperationState &result) {
537+
return parseWithEnumHandling<tosa::BlockSize>(parser, result);
538+
}
539+
540+
void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
541+
printWithEnumHandling(parser, *this);
542+
}
543+
544+
ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
545+
OperationState &result) {
546+
return parseWithEnumHandling<tosa::BlockSize>(parser, result);
547+
}
548+
549+
void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
550+
printWithEnumHandling(parser, *this);
551+
}
552+
535553
//===----------------------------------------------------------------------===//
536554
// Tosa utilities.
537555
//===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
39443962
return success();
39453963
}
39463964

3965+
LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
3966+
MLIRContext *context, ::std::optional<Location> location,
3967+
CastFromBlockScaledOp::Adaptor adaptor,
3968+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3969+
const ShapeAdaptor inputShape(adaptor.getInputData().getType());
3970+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3971+
return success();
3972+
}
3973+
3974+
LogicalResult CastFromBlockScaledOp::verify() {
3975+
const Type inputDataType = getInputData().getType();
3976+
const Type outputDataType = getResult().getType();
3977+
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
3978+
return emitOpError() << "require compatible shapes for input_data ("
3979+
<< inputDataType << ") and "
3980+
<< "output_data (" << outputDataType << ")";
3981+
3982+
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
3983+
3984+
if (inputDataShape.hasRank()) {
3985+
const unsigned int blockSize =
3986+
BlockSizeAttr::getBlockSizeValue(getBlockSize());
3987+
const int64_t inputDataLastDim =
3988+
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
3989+
if (inputDataLastDim % blockSize != 0)
3990+
return emitOpError() << "expect last dimension of input_data ("
3991+
<< inputDataLastDim
3992+
<< ") to be divisible by block_size (" << blockSize
3993+
<< ")";
3994+
3995+
const Type inputScaleType = getInputScale().getType();
3996+
const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
3997+
3998+
if (inputScaleShape.hasRank()) {
3999+
SmallVector<int64_t> inputDataDims, inputScaleDims;
4000+
inputDataShape.getDims(inputDataDims);
4001+
inputScaleShape.getDims(inputScaleDims);
4002+
4003+
if (inputDataDims.size() != inputScaleDims.size() ||
4004+
failed(verifyCompatibleShape(
4005+
ArrayRef<int64_t>(inputDataDims).drop_back(1),
4006+
ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4007+
return emitOpError() << "require compatible shapes for input_data ("
4008+
<< inputDataType << ") and "
4009+
<< "input_scale (" << inputScaleType
4010+
<< ") except for the last dimension";
4011+
4012+
const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4013+
inputScaleDims.back()};
4014+
if (ShapedType::isStatic(inputDataLastDim) &&
4015+
failed(verifyCompatibleDims(dimsToCheck)))
4016+
return emitOpError()
4017+
<< "expect last dimension of input_scale ("
4018+
<< inputScaleDims.back()
4019+
<< ") to be equal to last dimension of input_data / block_size ("
4020+
<< inputDataDims.back() / blockSize << ")";
4021+
}
4022+
}
4023+
4024+
return success();
4025+
}
4026+
4027+
LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4028+
MLIRContext *context, ::std::optional<Location> location,
4029+
CastToBlockScaledOp::Adaptor adaptor,
4030+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4031+
const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4032+
inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4033+
if (!inputShape.hasRank())
4034+
return success();
4035+
4036+
// Calculate output_scale shape if ranked input provided
4037+
SmallVector<int64_t> outputScaleShape;
4038+
inputShape.getDims(outputScaleShape);
4039+
const int64_t lastDimLoc = inputShape.getRank() - 1;
4040+
const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4041+
if (ShapedType::isStatic(lastDimSize)) {
4042+
const unsigned int blockSize =
4043+
BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4044+
outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4045+
}
4046+
inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4047+
return success();
4048+
}
4049+
4050+
LogicalResult CastToBlockScaledOp::verify() {
4051+
const Type inputDataType = getInputData().getType();
4052+
const Type outputDataType = getResult(0).getType();
4053+
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4054+
return emitOpError() << "require compatible shapes for input_data ("
4055+
<< inputDataType << ") and "
4056+
<< "output_data (" << outputDataType << ")";
4057+
4058+
const unsigned int blockSize =
4059+
BlockSizeAttr::getBlockSizeValue(getBlockSize());
4060+
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4061+
if (inputDataShape.hasRank()) {
4062+
const int64_t inputDataLastDim =
4063+
inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4064+
if (ShapedType::isStatic(inputDataLastDim) &&
4065+
inputDataLastDim % blockSize != 0)
4066+
return emitOpError() << "expect last dimension of input_data ("
4067+
<< inputDataLastDim
4068+
<< ") to be divisible by block_size (" << blockSize
4069+
<< ")";
4070+
}
4071+
4072+
const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4073+
const Type outputScaleType = getResult(1).getType();
4074+
const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4075+
if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4076+
SmallVector<int64_t> outputDataDims, outputScaleDims;
4077+
outputDataShape.getDims(outputDataDims);
4078+
outputScaleShape.getDims(outputScaleDims);
4079+
4080+
if (outputDataDims.size() != outputScaleDims.size() ||
4081+
failed(verifyCompatibleShape(
4082+
ArrayRef<int64_t>(outputDataDims).drop_back(1),
4083+
ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4084+
return emitOpError() << "require compatible shapes for output_data ("
4085+
<< outputDataType << ") and "
4086+
<< "output_scale (" << outputScaleType
4087+
<< ") except for the last dimension";
4088+
4089+
const int64_t outputDataLastDim = outputDataDims.back();
4090+
const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4091+
outputScaleDims.back()};
4092+
if (ShapedType::isStatic(outputDataLastDim) &&
4093+
failed(verifyCompatibleDims(dimsToCheck)))
4094+
return emitOpError()
4095+
<< "expect last dimension of output_scale ("
4096+
<< outputScaleDims.back()
4097+
<< ") to be equal to last dimension of output_data / block_size ("
4098+
<< outputDataDims.back() / blockSize << ")";
4099+
}
4100+
4101+
return success();
4102+
}
4103+
39474104
LogicalResult IfOp::inferReturnTypeComponents(
39484105
MLIRContext *context, ::std::optional<Location> location,
39494106
IfOp::Adaptor adaptor,

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
5151

5252
// Base populating function
5353
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
54-
Value output) {
55-
for (auto operand : operands)
54+
ValueRange results) {
55+
for (const auto &operand : operands)
5656
addValue(operand);
57-
addValue(output);
57+
for (const auto &result : results)
58+
addValue(result);
5859
return success();
5960
}
6061

@@ -176,23 +177,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
176177
return success();
177178
}
178179

179-
template <>
180-
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
181-
addValue(op.getInputReal());
182-
addValue(op.getInputImag());
183-
addValue(op.getOutputReal());
184-
addValue(op.getOutputImag());
185-
return success();
186-
}
187-
188-
template <>
189-
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
190-
addValue(op.getInputReal());
191-
addValue(op.getOutputReal());
192-
addValue(op.getOutputImag());
193-
return success();
194-
}
195-
196180
template <>
197181
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
198182
addValue(op.getOnTrue());
@@ -246,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
246230
// This helper function populates the info for all operands.
247231
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
248232
if (isa<tosa::tosaOp##Op>(op)) { \
249-
return populateProfileInfo(op->getOperands(), op->getResult(0)); \
233+
return populateProfileInfo(op->getOperands(), op->getResults()); \
250234
}
251235

252236
// Skip irrelevant operands when they are independent and not tied to any
@@ -257,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
257241
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
258242
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
259243
POPULATE_PROFILE_INFO_CUSTOM(Mul)
260-
POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
261-
POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
262244
POPULATE_PROFILE_INFO_CUSTOM(Concat)
263245
POPULATE_PROFILE_INFO_CUSTOM(Pad)
264246
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -277,7 +259,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
277259
// For the most of tosa operators, all operands are profile/extension related
278260
// and hence are all considered in this profile-based compilance check.
279261
POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
262+
POPULATE_PROFILE_INFO_COMMON(FFT2d)
263+
POPULATE_PROFILE_INFO_COMMON(RFFT2d)
280264
POPULATE_PROFILE_INFO_COMMON(Cast)
265+
POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
266+
POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
281267
POPULATE_PROFILE_INFO_COMMON(Const)
282268
POPULATE_PROFILE_INFO_COMMON(ArgMax)
283269
POPULATE_PROFILE_INFO_COMMON(Sub)

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
635635
CHECK_RANKS_AND_SIZES(Transpose);
636636
// Type Conversion
637637
CHECK_RANKS_AND_SIZES(Cast);
638+
CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
639+
CHECK_RANKS_AND_SIZES(CastToBlockScaled);
638640
CHECK_RANKS_AND_SIZES(Rescale);
639641
// Control Flow Operators
640642
CHECK_RANKS_AND_SIZES(If);

0 commit comments

Comments
 (0)