Skip to content

Commit fd22351

Browse files
committed
llvm::
Created using spr 1.3.4
2 parents b9a662d + 2619c2e commit fd22351

File tree

21 files changed

+362
-246
lines changed

21 files changed

+362
-246
lines changed

llvm/lib/CodeGen/ModuloSchedule.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ void ModuloScheduleExpander::generateExistingPhis(
412412
InitVal, NewReg);
413413
auto It = VRMap[CurStageNum].find(LoopVal);
414414
if (It != VRMap[CurStageNum].end()) {
415-
llvm::Register Reg = It->second;
415+
Register Reg = It->second;
416416
VRMap[CurStageNum][Def] = Reg;
417417
}
418418
}

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 174 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,20 @@ def ROCDL_DPPUpdateOp : ROCDL_IntrOp<"update.dpp", [], [0],
652652
}];
653653
}
654654

655+
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
656+
BuildableType<"::mlir::VectorType::get("
657+
"{2},$_builder.getI16Type())">;
658+
659+
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
660+
BuildableType<"::mlir::VectorType::get("
661+
"{2},$_builder.getF16Type())">;
662+
663+
def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
664+
BuildableType<"::mlir::VectorType::get("
665+
"{2},$_builder.getBF16Type())">;
666+
667+
// TODO: The word and byte selectors are immarg in LLVM
668+
// update to be attributes in MLIR
655669
//===---------------------------------------------------------------------===//
656670
// 16-bit float intrinsics
657671
//===---------------------------------------------------------------------===//
@@ -667,10 +681,168 @@ def ROCDL_CvtPkRtz:
667681
}];
668682
}
669683

684+
def ROCDL_CvtScaleF32PkFp8F16 :
685+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
686+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
687+
let summary = "Scale and convert f16 to packed fp8";
688+
let description = [{
689+
Scale `src` by the exponent in `scale` then convert to packed fp8.
690+
Store the result in low/high word based on $wordSel, preserving the other word.
691+
}];
692+
let assemblyFormat = [{
693+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
694+
}];
695+
}
696+
697+
def ROCDL_CvtScaleF32PkFp8Bf16 :
698+
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
699+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
700+
let summary = "Scale and convert packed bf16 to packed fp8";
701+
let description = [{
702+
Scale `src` by the exponent in `scale` then convert to packed fp8.
703+
Store the result in low/high word based on $wordSel, preserving the other word.
704+
}];
705+
let assemblyFormat = [{
706+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
707+
}];
708+
}
709+
710+
711+
def ROCDL_CvtScaleF32PkBf8F16 :
712+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
713+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
714+
let summary = "Scale and convert f16 to packed bf8";
715+
let description = [{
716+
Scale `src` by the exponent in `scale` then convert to packed bf8.
717+
Store the result in low/high word based on $wordSel, preserving the other word.
718+
}];
719+
let assemblyFormat = [{
720+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
721+
}];
722+
}
723+
724+
725+
def ROCDL_CvtScaleF32PkBf8Bf16 :
726+
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
727+
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
728+
let summary = "Scale and convert bf16 to packed bf8";
729+
let description = [{
730+
Scale `src` by the exponent in `scale` then convert to packed bf8.
731+
Store the result in low/high word based on $wordSel, preserving the other word.
732+
}];
733+
let assemblyFormat = [{
734+
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
735+
}];
736+
}
737+
738+
def ROCDL_CvtScaleF32SrFp8F16 :
739+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
740+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
741+
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
742+
let description = [{
743+
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
744+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
745+
746+
}];
747+
let assemblyFormat = [{
748+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
749+
}];
750+
}
751+
752+
def ROCDL_CvtScaleF32SrBf8F16 :
753+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
754+
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
755+
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
756+
let description = [{
757+
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
758+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
759+
760+
}];
761+
let assemblyFormat = [{
762+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
763+
}];
764+
}
765+
766+
def ROCDL_CvtScaleF32SrFp8Bf16 :
767+
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
768+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
769+
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
770+
let description = [{
771+
Scale `src` by the exponent in `scale` then convert to packed fp8 with stochastic rounding
772+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
773+
774+
}];
775+
let assemblyFormat = [{
776+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
777+
}];
778+
}
779+
780+
def ROCDL_CvtScaleF32SrBf8Bf16:
781+
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
782+
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
783+
let summary = "Scale and convert bf16 to packed bf8 using stochastic rounding";
784+
let description = [{
785+
Scale `src` by the exponent in `scale` then convert to packed bf8 with stochastic rounding
786+
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
787+
788+
}];
789+
let assemblyFormat = [{
790+
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
791+
}];
792+
}
793+
794+
def ROCDL_CvtScaleF32PkF16Fp8 :
795+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
796+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
797+
let summary = "Scale and convert fp8 to packed f16";
798+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
799+
then convert to packed f16.
800+
}];
801+
let assemblyFormat = [{
802+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
803+
}];
804+
}
805+
806+
def ROCDL_CvtScaleF32PkF16Bf8 :
807+
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
808+
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
809+
let summary = "Scale and convert bf8 to packed f16";
810+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
811+
then convert to packed f16.
812+
}];
813+
let assemblyFormat = [{
814+
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
815+
}];
816+
}
817+
818+
def ROCDL_CvtScaleF16Fp8 :
819+
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
820+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
821+
let summary = "Scale and convert fp8 to f16";
822+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
823+
then convert to f16 and store into the `byteSel`th byte of `old`, preserving the others.
824+
}];
825+
let assemblyFormat = [{
826+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
827+
}];
828+
}
829+
830+
def ROCDL_CvtScaleF16Bf8 :
831+
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
832+
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
833+
let summary = "Scale and convert bf8 to f16";
834+
let description = [{ Scale `src` based on $wordSel by the exponent in `scale`
835+
then convert to f16 and store into the `byteSel`th byte of `old`, preserving the others.
836+
}];
837+
let assemblyFormat = [{
838+
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
839+
}];
840+
}
841+
670842
//===---------------------------------------------------------------------===//
671843
// 32-bit float intrinsics
672844
//===---------------------------------------------------------------------===//
673-
def ROCDL_CvtScalePkF32Fp8 :
845+
def ROCDL_CvtScale32PkF32Fp8 :
674846
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
675847
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
676848
let summary = "Scale and convert packed fp8 to packed f32";
@@ -682,7 +854,7 @@ def ROCDL_CvtScalePkF32Fp8 :
682854
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
683855
}];
684856
}
685-
def ROCDL_CvtScalePkF32Bf8 :
857+
def ROCDL_CvtScale32PkF32Bf8 :
686858
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
687859
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
688860
let summary = "Scale and convert packed bf8 to packed f32";
@@ -697,10 +869,6 @@ def ROCDL_CvtScalePkF32Bf8 :
697869
//===---------------------------------------------------------------------===//
698870
// 8-bit float scale intrinsics
699871
//===---------------------------------------------------------------------===//
700-
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
701-
BuildableType<"::mlir::VectorType::get("
702-
"{2},$_builder.getI16Type())">;
703-
704872
def ROCDL_CvtScaleF32PkFp8F32:
705873
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
706874
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ profileComplianceMap = {
3535
{fp16T, fp16T, fp32T, fp32T},
3636
{fp32T, fp32T, fp32T, fp32T}}}}},
3737
{"tosa.matmul",
38-
{{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
38+
{{{Profile::pro_int}, {{i8T, i8T, i32T}}},
3939
{{Profile::pro_fp},
40-
{{fp16T, fp16T, fp16T, fp16T, fp16T},
41-
{fp16T, fp16T, fp16T, fp16T, fp32T},
42-
{fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
40+
{{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
4341
{"tosa.max_pool2d",
4442
{{{Profile::pro_int}, {{i8T, i8T}}},
4543
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -275,10 +273,10 @@ extensionComplianceMap = {
275273
{{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
276274
{{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
277275
{"tosa.matmul",
278-
{{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
279-
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
280-
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
281-
{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
276+
{{{Extension::int16}, {{i16T, i16T, i48T}}},
277+
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
278+
{{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
279+
{{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
282280
{"tosa.max_pool2d",
283281
{{{Extension::int16}, {{i16T, i16T}}},
284282
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
311311
let arguments = (ins
312312
Tosa_Tensor3D:$a,
313313
Tosa_Tensor3D:$b,
314-
Tosa_ScalarIntOrFloatTensor:$a_zp,
315-
Tosa_ScalarIntOrFloatTensor:$b_zp
314+
OptionalAttr<I32Attr>:$a_zp,
315+
OptionalAttr<I32Attr>:$b_zp
316316
);
317317

318318
let results = (outs
@@ -324,13 +324,6 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
324324
Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
325325
];
326326

327-
let extraClassDeclaration = [{
328-
FailureOr<int64_t> getAZeroPoint();
329-
FailureOr<int64_t> getBZeroPoint();
330-
LogicalResult verifyAZeroPoint(int64_t zp);
331-
LogicalResult verifyBZeroPoint(int64_t zp);
332-
}];
333-
334327
let builders = [Tosa_MatMulOpQuantInfoBuilder];
335328
let hasVerifier = 1;
336329
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
270270
return rewriter.notifyMatchFailure(
271271
op, "weight zero point cannot be statically determined");
272272

273-
const int64_t inputZpVal = *maybeIZp;
274-
const int64_t weightZpVal = *maybeWZp;
273+
int64_t inputZpVal = *maybeIZp;
274+
int64_t weightZpVal = *maybeWZp;
275275

276276
if (op.verifyInputZeroPoint(inputZpVal).failed())
277277
return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
466466
return rewriter.notifyMatchFailure(
467467
op, "weight zero point cannot be statically determined");
468468

469-
const int64_t inputZpVal = *maybeIZp;
470-
const int64_t weightZpVal = *maybeWZp;
469+
int64_t inputZpVal = *maybeIZp;
470+
int64_t weightZpVal = *maybeWZp;
471471

472472
if (op.verifyInputZeroPoint(inputZpVal).failed())
473473
return rewriter.notifyMatchFailure(
@@ -621,38 +621,15 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
621621
.create<linalg::FillOp>(loc, ValueRange{zero},
622622
ValueRange{emptyTensor})
623623
.result();
624-
625-
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
626-
FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
627-
if (failed(maybeAZp))
628-
return rewriter.notifyMatchFailure(
629-
op, "input a zero point cannot be statically determined");
630-
if (failed(maybeBZp))
631-
return rewriter.notifyMatchFailure(
632-
op, "input b zero point cannot be statically determined");
633-
634-
const int64_t aZpVal = *maybeAZp;
635-
const int64_t bZpVal = *maybeBZp;
636-
637-
if (op.verifyAZeroPoint(aZpVal).failed())
638-
return rewriter.notifyMatchFailure(
639-
op, "input a zero point must be zero for non-int8 integer types");
640-
641-
if (op.verifyBZeroPoint(bZpVal).failed())
642-
return rewriter.notifyMatchFailure(
643-
op, "input b zero point must be zero for non-int8 integer types");
644-
645-
if (aZpVal == 0 && bZpVal == 0) {
624+
if (!op.getAZp() && !op.getBZp()) {
646625
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
647626
op, TypeRange{op.getType()},
648627
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
649628
return success();
650629
}
651630

652-
auto aZp = rewriter.create<arith::ConstantOp>(
653-
loc, rewriter.getI32IntegerAttr(aZpVal));
654-
auto bZp = rewriter.create<arith::ConstantOp>(
655-
loc, rewriter.getI32IntegerAttr(bZpVal));
631+
auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
632+
auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
656633
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
657634
op, TypeRange{op.getType()},
658635
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -857,8 +834,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
857834
return rewriter.notifyMatchFailure(
858835
op, "output zero point could not be statically determined");
859836

860-
const int64_t inputZpVal = *maybeIZp;
861-
const int64_t outputZpVal = *maybeOZp;
837+
int64_t inputZpVal = *maybeIZp;
838+
int64_t outputZpVal = *maybeOZp;
862839

863840
// Apply padding as necessary.
864841
llvm::SmallVector<int64_t> pad;

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ struct MatMulOpSharding
5555
SmallVector<AffineMap> maps;
5656
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
5757
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
58-
maps.push_back(AffineMap::get(0, 0, {}, ctx));
59-
maps.push_back(AffineMap::get(0, 0, {}, ctx));
6058
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
6159
return maps;
6260
}

0 commit comments

Comments
 (0)