
Commit f787eb7

Merge commit 'b5fea1e3f4c2cb0b40c0ce98261b240d8728d2f9'
2 parents: 39907eb + b5fea1e

34 files changed, +1,221 −423 lines

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 6 deletions
@@ -181,12 +181,7 @@ jobs:
       - name: Run Proton tests
         run: |
           unset HIP_VISIBLE_DEVICES
-          unset ROCR_VISIBLE_DEVICES
-          if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ]; then
-            python3 -m pytest -s -n 8 third_party/proton/test -k "not test_instrument_exec"
-          else
-            make test-proton
-          fi
+          make test-proton
       - name: Inspect cache directories
         run: |
           mkdir -p ~/.triton

.github/workflows/llvm-build.yml

Lines changed: 0 additions & 25 deletions
@@ -31,7 +31,6 @@ jobs:
       config:
         - {runner: 'Ubuntu 22.04', runs_on: 'ubuntu-22.04', target-os: 'ubuntu', arch: 'x64'}
         - {runner: 'Ubuntu 22.04 ARM64', runs_on: 'ubuntu-22.04', target-os: 'ubuntu', arch: 'arm64'}
-        - {runner: 'CentOS 7', runs_on: ['self-hosted', 'CPU'], target-os: 'centos', arch: 'x64'}
         - {runner: 'AlmaLinux 8', runs_on: ['self-hosted', 'CPU'], target-os: 'almalinux', arch: 'x64'}
         - {runner: 'AlmaLinux 8 ARM64', runs_on: 'ubuntu-22.04-arm', target-os: 'almalinux', arch: 'arm64'}
         - {runner: 'MacOS X64', runs_on: 'macos-13', target-os: 'macos', arch: 'x64'}
@@ -233,30 +232,6 @@ jobs:
 
           tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
 
-
-      - name: Configure, Build, Test, and Install LLVM (CentOS)
-        if: matrix.config.target-os == 'centos'
-        run: |
-          # if this step crashes, it can leave behind a stale docker container
-          docker container prune -f
-          docker rmi -f $(docker images -q)
-
-          docker build --tag llvm-build --build-arg llvm_dir=llvm-project \
-            -f llvm-build/.github/workflows/llvm-build/centos.Dockerfile .
-
-          # Create temporary container to copy cache and installed artifacts.
-          CONTAINER_ID=$(docker create llvm-build)
-          docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}"
-          tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"
-
-          # We remove the existing directory, otherwise docker will
-          # create a subdirectory inside the existing directory.
-          rm -rf "${{ env.SCCACHE_DIR }}"
-          docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}"
-          sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}"
-
-          docker rm "${CONTAINER_ID}"
-
       - name: Configure, Build, Test, and Install LLVM (AlmaLinux)
         if: matrix.config.target-os == 'almalinux'
         run: |

.github/workflows/llvm-build/centos.Dockerfile

Lines changed: 0 additions & 56 deletions
This file was deleted.

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@ if(POLICY CMP0116)
   cmake_policy(SET CMP0116 OLD)
 endif()
 
-include(ExternalProject)
-
 set(CMAKE_CXX_STANDARD 17)
 
 set(CMAKE_INCLUDE_CURRENT_DIR ON)

lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp

Lines changed: 81 additions & 4 deletions
@@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// FunctionOpInterfaceSignatureConversion
+//===----------------------------------------------------------------------===//
+// NOTE: Forked from mlir to support remapping argument attributes correctly in
+// a one-to-many type conversion.
+
+SmallVector<Attribute>
+convertFuncOpAttrs(FunctionOpInterface funcOp,
+                   TypeConverter::SignatureConversion &sigConv,
+                   FunctionType newType) {
+  if (newType.getNumInputs() == funcOp.getNumArguments()) {
+    return {};
+  }
+  ArrayAttr allArgAttrs = funcOp.getAllArgAttrs();
+  if (!allArgAttrs)
+    return {};
+
+  SmallVector<Attribute> newAttrs(newType.getNumInputs());
+  for (auto i : llvm::seq(allArgAttrs.size())) {
+    auto mapping = sigConv.getInputMapping(i);
+    assert(mapping.has_value());
+    auto outIdx = mapping->inputNo;
+    newAttrs[outIdx] = allArgAttrs[i];
+  }
+  return newAttrs;
+}
+
+LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+                                 const TypeConverter &typeConverter,
+                                 ConversionPatternRewriter &rewriter) {
+  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
+  if (!type)
+    return failure();
+
+  // Convert the original function types.
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                         typeConverter, &result)))
+    return failure();
+
+  // Update the function signature in-place.
+  auto newType = FunctionType::get(rewriter.getContext(),
+                                   result.getConvertedTypes(), newResults);
+
+  auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType);
+
+  rewriter.modifyOpInPlace(funcOp, [&] {
+    funcOp.setType(newType);
+    if (!newArgAttrs.empty()) {
+      funcOp.setAllArgAttrs(newArgAttrs);
+    }
+  });
+
+  return success();
+}
+
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FunctionOpInterface op. This only supports ops which use FunctionType to
+/// represent their type.
+struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
+  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
+                                         MLIRContext *ctx,
+                                         const TypeConverter &converter,
+                                         PatternBenefit benefit = 1)
+      : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+  }
+};
+
 } // namespace
 
 void populateFunctionTypeConversions(const TypeConverter &converter,
                                      RewritePatternSet &patterns) {
-  mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
-      patterns, converter);
-  patterns.add<CallOpConversion, ReturnOpConversion>(converter,
-                                                     patterns.getContext());
+  auto context = patterns.getContext();
+  patterns.add<FunctionOpInterfaceSignatureConversion>(
+      triton::FuncOp::getOperationName(), context, converter);
+  patterns.add<CallOpConversion, ReturnOpConversion>(converter, context);
 }
 
 } // namespace mlir::triton
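
The NOTE explains why this pattern is forked rather than reusing mlir::populateFunctionOpInterfaceTypeConversionPattern: when one original argument is converted into several new ones, the argument attributes have to follow the signature mapping instead of being copied index-for-index. A minimal sketch of the remapping that convertFuncOpAttrs performs (plain Python with made-up names, purely illustrative, not the MLIR API):

    # Each original argument i converts into a contiguous range of new arguments
    # starting at first_new_index[i]; its attributes are attached to the first
    # argument of that range and the extra arguments carry no attributes.
    def remap_arg_attrs(old_attrs, first_new_index, num_new_args):
        new_attrs = [None] * num_new_args
        for i, attrs in enumerate(old_attrs):
            new_attrs[first_new_index[i]] = attrs
        return new_attrs

    # Argument 0 expands into two values while argument 1 stays one-to-one, so
    # the attributes of argument 1 must move from slot 1 to slot 2.
    print(remap_arg_attrs(["attrs_of_arg0", "attrs_of_arg1"],
                          first_new_index=[0, 2], num_new_args=3))
    # -> ['attrs_of_arg0', None, 'attrs_of_arg1']

Note that convertFuncOpAttrs returns early when the argument count is unchanged, so the common one-to-one case keeps its attributes untouched.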

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 8 additions & 6 deletions
@@ -754,17 +754,19 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
   Value barrierAlloc = createBarrierAlloc(forOp, numStages);
   Value vTrue = builder.create<arith::ConstantIntOp>(1, 1);
   Value phase = forOp.getRegionIterArg(phaseArgIdx);
-  Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
+  Value barrierIdx;
+  if (numStages > 1) {
+    barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
+  } else {
+    barrierIdx = zero;
+  }
   Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
   Value numStagesVal =
       builder.create<arith::ConstantIntOp>(forOp.getLoc(), numStages, 32);
 
-  Value barrierSlice = barrierAlloc;
-  if (numStages > 1) {
-    barrierSlice =
-        triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
-  }
+  Value barrierSlice =
+      triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
   mma.addCompletionBarrier(barrierSlice, vTrue);
   mma.setIsAsync(true);
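
The rewrite makes the single-stage case explicit: only when numStages > 1 does the loop provide a barrierIdx iter arg, so for a single stage the index is pinned to the constant zero and the barrier slice can now be taken unconditionally through createSingleBufferView. A tiny illustration (plain Python, not the pass itself) of why a one-deep ring of barriers always resolves to index 0:

    # Illustrative only: ring-buffer barrier indices for different stage counts.
    def barrier_indices(num_iters, num_stages):
        idx, out = 0, []
        for _ in range(num_iters):
            out.append(idx)
            idx = (idx + 1) % num_stages
        return out

    print(barrier_indices(5, num_stages=3))  # [0, 1, 2, 0, 1]
    print(barrier_indices(5, num_stages=1))  # [0, 0, 0, 0, 0]
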
python/src/gluon_ir.cc

Lines changed: 6 additions & 0 deletions
@@ -812,6 +812,12 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(descPtr, indices,
                                                            result, pred);
           })
+      .def("create_async_tdm_copy_local_to_global",
+           [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
+              Value src) {
+             self.create<ttag::AsyncTDMCopyLocalToGlobalOp>(descPtr, indices,
+                                                            src);
+           })
       .def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) {
         ValueRange tokens;
         self.create<ttag::AsyncTDMWait>(tokens, num);
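
The new binding mirrors the existing global-to-local TDM copy, but in the store direction. A hedged usage sketch follows; only the two builder method names, create_async_tdm_copy_local_to_global (added here) and create_async_tdm_wait (already present), come from this file, while the wrapper itself and its handle plumbing are assumptions about how a frontend helper could drive them:

    # Hypothetical helper (illustrative only): copy a shared-memory buffer back
    # to the global tensor described by desc_handle at the given block indices.
    def tdm_store_and_wait(builder, desc_handle, index_handles, src_handle):
        builder.create_async_tdm_copy_local_to_global(desc_handle, index_handles,
                                                      src_handle)
        # Wait on outstanding TDM operations (num=0 presumably means "all done").
        builder.create_async_tdm_wait(0)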

python/test/gluon/test_core.py

Lines changed: 91 additions & 0 deletions
@@ -30,7 +30,9 @@
     TensorMemoryScalesLayout,
     allocate_tensor_memory,
     get_tmem_32x32b_reg_layout,
+    get_tmem_scales_reg_layout,
     tcgen05_mma,
+    tcgen05_mma_scaled,
     tcgen05_commit,
     tcgen05_copy,
     float2,
@@ -1334,3 +1336,92 @@ def kernel_auto_layout_constant(threads_per_warp: ttgl.constexpr):
 
 def test_auto_layout_constant():
     kernel_auto_layout_constant.warmup(THREADS_PER_WARP, grid=(1, ))
+
+
+def fp8e8m0_to_float32(scale):
+    scale = scale.view(torch.uint8)
+    scale = scale.to(torch.int32)
+    scale = scale << 23
+    scale = scale.view(torch.float32)
+    return scale
+
+
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+def test_tcgen05_mma_scaled_minimal():
+    M = 128
+    N = 128
+    K = 128
+    threads_per_warp = ttgl.constexpr(THREADS_PER_WARP)
+
+    @gluon.jit
+    def kernel(out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, a, b, a_scale, b_scale):
+        # Simple register layout for creating constants and storing results
+        reg_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [threads_per_warp, 1], [ttgl.num_warps(), 1], [1, 0])
+
+        # Shared-memory layouts for MMA operands
+        nvmma_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, transposed=False,
+                                                              element_bitwidth=8, rank=2)
+        # Load the fp8 operand tiles from global memory and stage them in shared memory
+        block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], warps_per_cta=[ttgl.num_warps(), 1],
+                                                          order=[1, 0])
+        a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+        a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+        b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+        b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+
+        a_tile = ttgl.load(a + a_offs_m * K + a_offs_k)
+        b_tile = ttgl.load(b + b_offs_k * N + b_offs_n)
+        a_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [M, K], nvmma_layout, a_tile)
+        b_smem = ttgl.allocate_shared_memory(ttgl.float8e5, [K, N], nvmma_layout, b_tile)
+
+        # Accumulator in TMEM initialized to zeros
+        acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout([M, N], col_stride=1)
+        tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(M, N, [M, N], ttgl.num_warps())
+        acc_init = ttgl.zeros([M, N], ttgl.float32, layout=tmem_reg_layout)
+        acc_tmem = allocate_tensor_memory(ttgl.float32, [M, N], acc_tmem_layout, acc_init)
+
+        # Load the e8m0 scales and stage them in TMEM
+        scale_layout: ttgl.constexpr = TensorMemoryScalesLayout()
+        scale_reg_layout: ttgl.constexpr = get_tmem_scales_reg_layout(M, N, [M, N], ttgl.num_warps())
+        scale_offs_k = ttgl.arange(0, (K // 32), layout=ttgl.SliceLayout(0, scale_reg_layout))[None, :]
+        scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None]
+        scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, scale_reg_layout))[:, None]
+        a_scale_init = ttgl.load(a_scale + scale_offs_m * (K // 32) + scale_offs_k)
+        b_scale_init = ttgl.load(b_scale + scale_offs_n * (K // 32) + scale_offs_k)
+        a_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, a_scale_init)
+        b_scale_tmem = allocate_tensor_memory(ttgl.int8, [M, K // 32], scale_layout, b_scale_init)
+
+        # Issue a single scaled MMA and commit
+        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+        mbarrier.init(bar, count=1)
+        tcgen05_mma_scaled(a_smem, b_smem, acc_tmem, a_scale_tmem, b_scale_tmem, "e5m2", "e5m2", use_acc=True)
+        tcgen05_commit(bar)
+        mbarrier.wait(bar, phase=0)
+
+        # Load result from TMEM and store to global
+        out_reg = acc_tmem.load(tmem_reg_layout)
+        store_layout: ttgl.constexpr = reg_layout
+        offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, store_layout))[:, None]
+        offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, store_layout))[None, :]
+        offs = offs_m * N + offs_n
+        ttgl.store(out_ptr + offs, ttgl.convert_layout(out_reg, store_layout))
+
+    out = torch.empty((M, N), dtype=torch.float32, device="cuda")
+    a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2)
+    b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device="cuda").view(torch.float8_e5m2)
+    a_scale = torch.randint(64, 130, (M, K // 32), dtype=torch.uint8, device="cuda")
+    b_scale = torch.randint(64, 130, (N, K // 32), dtype=torch.uint8, device="cuda")
+    compiled = kernel[(1, )](out, M, N, K, a, b, a_scale, b_scale)
+    A = a.to(torch.float32)
+    B = b.to(torch.float32)
+    a_scale_f32 = fp8e8m0_to_float32(a_scale)
+    b_scale_f32 = fp8e8m0_to_float32(b_scale)
+    a_scale_f32 = a_scale_f32.repeat_interleave(32, dim=1)
+    b_scale_f32 = b_scale_f32.repeat_interleave(32, dim=1)
+    b_scale_f32 = b_scale_f32.T.contiguous()
+    A = A * a_scale_f32
+    B = B * b_scale_f32
+    ref = torch.matmul(A, B)
+    torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-6)
+    ttgir = compiled.asm["ttgir"]
+    assert "ttng.tc_gen5_mma_scaled" in ttgir
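
Two details of the reference computation deserve a note. First, fp8e8m0_to_float32 is just a left shift by 23 because an e8m0 scale is an 8-bit exponent with bias 127: placing it in the float32 exponent field with zero sign and mantissa yields exactly 2 ** (e - 127). A self-contained check in plain Python (struct is used here only for the illustration):

    import struct

    def e8m0_to_float(e):
        # Put the 8-bit exponent into the float32 exponent field (bits 30..23);
        # the result is 2 ** (e - 127), the value the e8m0 encoding represents.
        return struct.unpack("f", struct.pack("I", e << 23))[0]

    assert e8m0_to_float(127) == 1.0       # the bias point
    assert e8m0_to_float(130) == 8.0       # 2 ** 3
    assert e8m0_to_float(64) == 2.0 ** -63

Second, each scale covers a group of 32 elements along K (the scale tensors have shape (M, K // 32) and (N, K // 32)), which is why the reference expands them with repeat_interleave(32, dim=1) before multiplying into A and B.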
