Skip to content

Commit 63ab535

Browse files
authored
Merge branch 'main' into liyang/upcast_mxfp_and_dot_scaled
2 parents 06aa7ef + 7ac1225 commit 63ab535

File tree

8 files changed

+17
-20
lines changed

8 files changed

+17
-20
lines changed

.github/workflows/auto-update-translator-cid.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
6767
- name: Get commit ID from Triton's spirv-llvm-translator.conf
6868
if: ${{ env.TARGET_PRID == null }}
69-
run: echo "CID_CURRENT=$(cat ./lib/Target/SPIRV/spirv-llvm-translator.conf)" >> $GITHUB_ENV
69+
run: echo "CID_CURRENT=$(<third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf)" >> $GITHUB_ENV
7070

7171
- name: Checkout SPIRV-LLVM-Translator
7272
if: ${{ env.TARGET_PRID == null }}
@@ -82,12 +82,12 @@ jobs:
8282
run: |
8383
echo "CID_LATEST=$(git -C external/SPIRV-LLVM-Translator rev-parse HEAD)" >> $GITHUB_ENV
8484
85-
# the latest valid cid has been weitten to spirv-llvm-translator.conf
85+
# the latest valid cid has been written to spirv-llvm-translator.conf
8686
- name: Search the latest valid Translator cid
8787
if: ${{ env.TARGET_PRID == null }}
8888
run: |
8989
./scripts/check-update-translator-cid.sh $CID_LATEST $CID_CURRENT
90-
if git status --porcelain ./lib/Target/SPIRV/spirv-llvm-translator.conf | grep '^ M'; then
90+
if git status --porcelain third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf | grep '^ M'; then
9191
echo "MODIFIED=true" >> $GITHUB_ENV
9292
echo "spirv-llvm-translator.conf has been modified"
9393
fi
@@ -112,7 +112,7 @@ jobs:
112112
git checkout -b ${PR_BRANCH}
113113
git branch --show-current # bot/update_translator_cid
114114
git status
115-
git add ./lib/Target/SPIRV/spirv-llvm-translator.conf
115+
git add third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf
116116
git commit -m "Update spirv-llvm-translator.conf"
117117
git push origin ${PR_BRANCH}
118-
pr_url=$(gh pr create --title "[github-bot] Update spirv-llvm-translator.conf" --body "Automated PR to update translator commit id." --reviewer whitneywhtsang --head ${PR_BRANCH} --base main)
118+
gh pr create --title "[github-bot] Update spirv-llvm-translator.conf" --body "Automated PR to update translator commit id." --reviewer whitneywhtsang --head ${PR_BRANCH} --base main

.github/workflows/triton-benchmarks.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ jobs:
202202
python gemm_streamk_benchmark.py --reports $REPORTS
203203
source ../../scripts/capture-hw-details.sh
204204
python ../../scripts/build_report.py $REPORTS/matmul-streamk-performance.csv $REPORTS/gemm-streamk-triton-report.csv --benchmark gemm-streamk --compiler triton --param_cols "M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
205+
python ../../scripts/build_report.py $REPORTS/matmul-streamk-performance.csv $REPORTS/gemm-streamk-xetla-report.csv --benchmark gemm-streamk --compiler xetla --param_cols "M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
205206
206207
- name: Run Triton GEMM (split-k) kernel benchmark
207208
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_splitk_benchmark.py') }}
@@ -210,6 +211,7 @@ jobs:
210211
python gemm_splitk_benchmark.py --reports $REPORTS
211212
source ../../scripts/capture-hw-details.sh
212213
python ../../scripts/build_report.py $REPORTS/matmul-splitk-performance.csv $REPORTS/gemm-splitk-triton-report.csv --benchmark gemm-splitk --compiler triton --param_cols "M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
214+
python ../../scripts/build_report.py $REPORTS/matmul-splitk-performance.csv $REPORTS/gemm-splitk-xetla-report.csv --benchmark gemm-splitk --compiler xetla --param_cols "M,K,N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
213215
214216
- name: Run Triton GEMM + PreOp (exp) kernel benchmark
215217
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_preop_exp_benchmark.py') }}

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
1515
}
1616

1717
// Floating-point Type
18-
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
18+
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
1919
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
2020
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
2121

python/src/ir.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -807,11 +807,6 @@ void init_triton_ir(py::module &&m) {
807807
[](TritonOpBuilder &self) -> Type {
808808
return self.getBuilder().getI8Type();
809809
})
810-
.def("get_fp8e4m3b11fnuz_ty",
811-
[](TritonOpBuilder &self) -> Type {
812-
// TODO: align with upstream code to use i8
813-
return self.getBuilder().getType<Float8E4M3B11FNUZType>();
814-
})
815810
.def("get_fp8e5_ty",
816811
[](TritonOpBuilder &self) -> Type {
817812
return self.getBuilder().getType<Float8E5M2Type>();
@@ -1680,11 +1675,11 @@ void init_triton_ir(py::module &&m) {
16801675
if (haveDiagnostics) {
16811676
context->printOpOnDiagnostic(true);
16821677
context->printStackTraceOnDiagnostic(true);
1678+
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
1679+
llvm::outs() << diag << "\n";
1680+
return success();
1681+
});
16831682
}
1684-
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
1685-
llvm::outs() << diag << "\n";
1686-
return success();
1687-
});
16881683
if (haveDump) {
16891684
auto printingFlags = OpPrintingFlags();
16901685
printingFlags.elideLargeElementsAttrs(16);

scripts/check-update-translator-cid.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ COMMIT_IDS=$(git -C $TRITON_PROJ/external/SPIRV-LLVM-Translator log --format="%H
2525
cd $TRITON_PROJ
2626
FOUND=false
2727
for cid in $COMMIT_IDS; do
28-
echo "$cid" > ./lib/Target/SPIRV/spirv-llvm-translator.conf
28+
echo "$cid" > third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf
2929

3030
BUILD_STATUS=PASS
3131
echo "::group::Building Triton for $cid"
@@ -52,5 +52,5 @@ for cid in $COMMIT_IDS; do
5252
done
5353

5454
if [ "$FOUND" = false ]; then
55-
git restore ./lib/Target/SPIRV/spirv-llvm-translator.conf
55+
git restore third_party/intel/lib/Target/SPIRV/spirv-llvm-translator.conf
5656
fi

third_party/intel/language/intel/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def num_warps(_builder=None):
2525

2626
def convert_fp8e4b15_to_float16(arg, _builder):
2727
# Need to bitcast the source first because it's represented as tensor of i8 in MLIR.
28-
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4m3b11fnuz_ty(), arg.type.shape)
28+
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4b8_ty(), arg.type.shape)
2929
tmp = _builder.create_bitcast(arg.handle, tmp_ty)
3030
# Now generate FpToFp op for upcast.
3131
dst_ty = core.block_type(core.float16, arg.type.get_block_shapes())
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
c1828a93d050afe8a44cc239a2908270c578c345
1+
50af2c70832745be1c0ce1562c4038162c319e5c

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ struct FpToFpOpConversion
974974
std::pair<ConverterT, size_t>
975975
getConversionFunc(Type srcTy, Type dstTy,
976976
std::optional<RoundingMode> roundingMode) const {
977-
auto F8E4M3B15TyID = TypeID::get<Float8E4M3B11FNUZType>();
977+
auto F8E4M3B15TyID = TypeID::get<Float8E4M3FNUZType>();
978978
auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
979979
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
980980
auto F16TyID = TypeID::get<Float16Type>();

0 commit comments

Comments
 (0)