Commit e1b256f
Merge commit 'b220c76447831e169df2c2d67f950e97774b2cd3'
2 parents: 1a94e46 + b220c76

13 files changed: +330 −115 lines

.github/workflows/integration-tests-amd.yml

Lines changed: 3 additions & 0 deletions
@@ -166,6 +166,9 @@ jobs:
           # Reenable test_functional_regression.py once it's fixed
           cd python/test/regression
           python3 -m pytest -s -n 8 ./test_cast_matmul.py
+      - name: Run microbenchmark tests
+        run: |
+          python3 python/test/microbenchmark/launch_overhead.py
       - name: Run Proton tests
         run: |
           unset HIP_VISIBLE_DEVICES

.github/workflows/integration-tests-nvidia.yml

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,9 @@ jobs:
        run: make test-interpret
      - name: Run regression tests
        run: make test-regression
+     - name: Run microbenchmark tests
+       # Microbenchmarks never fail, but running them gives us an easy way to track performance changes.
+       run: make test-microbenchmark
      - name: Run C++ unittests
        run: make test-cpp
      - name: Run Proton tests

Makefile

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ test-gluon: all
 test-regression: all
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/regression
 
+.PHONY: test-microbenchmark
+test-microbenchmark: all
+	$(PYTHON) python/test/microbenchmark/launch_overhead.py
+
 .PHONY: test-interpret
 test-interpret: all
 	cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \

python/src/gluon_ir.cc

Lines changed: 5 additions & 6 deletions
@@ -217,8 +217,8 @@ py::object layoutToGluon(Attribute layout) {
 
     return layouts.AMDMFMALayout(
         amdMfma.getVersion(), instrShape, amdMfma.getIsTransposed(),
-        toStdVector(amdMfma.getWarpsPerCTA()),
-        toStdVector(amdMfma.getTilesPerWarp()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getWarpsPerCTA()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getTilesPerWarp()),
         toStdVector(ctaLayout.getCTAsPerCGA()),
         toStdVector(ctaLayout.getCTASplitNum()),
         toStdVector(ctaLayout.getCTAOrder()));

@@ -325,13 +325,12 @@ void init_gluon_ir(py::module &&m) {
       })
       .def("get_amd_mfma_layout",
           [](GluonOpBuilder &self, unsigned version,
+             std::vector<unsigned> &instrShape, bool transposed,
+             std::vector<unsigned> &warpsPerCta, mlir::Type elemType,
              std::vector<unsigned> &tilesPerWarp,
-             std::vector<unsigned> &warpsPerCta,
             std::vector<unsigned> &ctasPerCga,
             std::vector<unsigned> &ctaSplitNum,
-             std::vector<unsigned> &ctaOrder,
-             std::vector<unsigned> &instrShape, bool transposed,
-             mlir::Type elemType) -> Attribute {
+             std::vector<unsigned> &ctaOrder) -> Attribute {
            auto ctx = self.getContext();
            auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
                ctx, ctasPerCga, ctaSplitNum, ctaOrder);
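
Editor's note: the reordered binding above lines the C++ lambda's parameters up with the Python-side AMDMFMALayout constructor (version, instr_shape, transposed, warps_per_cta, then elem_type / tiles_per_warp and the CTA fields). A minimal sketch of the corresponding Gluon-side construction, mirroring the keyword order used in the tests below; the omitted CTA arguments falling back to defaults is inferred from those tests, not stated here:

from triton.experimental.gluon import language as ttgl

# Keyword order mirrors the reordered binding: version, instr_shape, transposed,
# warps_per_cta; elem_type, tiles_per_warp and the CTA fields default when omitted.
mfma = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
                              warps_per_cta=[4, 1])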

python/test/gluon/test_core.py

Lines changed: 64 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import pytest
 
-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier

@@ -143,3 +143,66 @@ def test_warpgroup_mma(ASYNC):
     ref = torch.matmul(a, b)
 
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
+@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
+@pytest.mark.parametrize("num_warps", [4, 8])
+@pytest.mark.parametrize("cdna_version", [3, 4])
+def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak,  #
+               stride_bk, stride_bn,  #
+               stride_cm, stride_cn, BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr,
+               BLOCK_SIZE_K: ttgl.constexpr, blocked: ttgl.constexpr, mfma_layout: ttgl.constexpr):
+        dot_a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=8)
+        dot_b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=8)
+
+        offs_am = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_bn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+
+        offs_ak = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(0, blocked))
+        offs_bk = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(1, blocked))
+        offs_a = offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak
+        offs_b = offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
+        a = ttgl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=offs_a)
+        b = ttgl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=offs_b)
+        a1 = ttgl.convert_layout(a, layout=dot_a_layout)
+        b1 = ttgl.convert_layout(b, layout=dot_b_layout)
+        acc = ttgl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], ttgl.float32, mfma_layout)
+        c = ttgl.amd.cdna3.mfma(a1, b1, acc)
+        c = ttgl.convert_layout(c, layout=blocked)
+        c = c.to(a_ptr.dtype.element_ty)
+
+        offs_cm = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_cn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+        offs_c = offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        ttgl.amd.cdna3.buffer_store(stored_value=c, ptr=c_ptr, offsets=offs_c)
+
+    if not is_hip_cdna4() and not is_hip_cdna3():
+        pytest.skip("mfma requires target to be CDNA3 or CDNA4")
+
+    if is_hip_cdna3() and cdna_version != 3:
+        pytest.skip("On CDNA3 target, skip if mfma version is not 3")
+
+    if is_hip_cdna4() and cdna_version != 4:
+        pytest.skip("On CDNA4 target, skip if mfma version is not 4")
+
+    elem_type = torch.float16 if in_dtype == 'float16' else torch.bfloat16
+    a = torch.randn((M, K), device='cuda', dtype=elem_type) - 0.5
+    b = torch.randn((K, N), device='cuda', dtype=elem_type) - 0.5
+    c = torch.empty((M, N), device=a.device, dtype=elem_type)
+    nonkdim: ttgl.constexpr = 32
+    blocked: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[4, 4], threads_per_warp=[4, 16],
+                                                 warps_per_cta=[num_warps, 1], order=[1, 0])
+    mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=cdna_version, instr_shape=[nonkdim, nonkdim],
+                                                         transposed=True, warps_per_cta=[num_warps, 1])
+
+    kernel[1, 1](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=M,
+                 BLOCK_SIZE_N=N, BLOCK_SIZE_K=K, blocked=blocked, mfma_layout=mfma_layout, num_warps=num_warps)
+
+    ref = torch.matmul(a, b)
+    triton_output = c
+    torch.testing.assert_close(ref, triton_output)

python/test/gluon/test_frontend.py

Lines changed: 78 additions & 35 deletions
@@ -1413,31 +1413,27 @@ def test_atomic_cas():
 
 @gluon.jit
 def amd_mfma_layout_kernel():
-    mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 ctas_per_cga=[1,
-                                                                               1], cta_split_num=[1,
-                                                                                                  1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32],
+                                                                           transposed=True, warps_per_cta=[4, 1]))
 
-    mfma_layout_fp64: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 elem_type=ttgl.float64, ctas_per_cga=[1, 1],
-                                                                 cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], tiles_per_warp=[4, 1], transposed=True,
+                                               warps_per_cta=[4, 1]))
 
-    mfma_layout_int32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                  warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                  elem_type=ttgl.int32, ctas_per_cga=[1, 1],
-                                                                  cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True, warps_per_cta=[4, 1],
+                                               ctas_per_cga=[1, 1], tiles_per_warp=[1, 1], cta_split_num=[1, 1],
+                                               cta_order=[1, 0]))
 
-    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+    ttgl.full([128, 32], 0, ttgl.float64,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.float64, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0]))
 
-    x_fp32 = ttgl.full([128, 32], 0, ttgl.float32, layout)
-    x_fp64 = ttgl.full([128, 32], 0, ttgl.float64, layout)
-    x_int32 = ttgl.full([128, 32], 0, ttgl.int32, layout)
-
-    ttgl.convert_layout(x_fp32, mfma_layout_fp32)
-    ttgl.convert_layout(x_fp64, mfma_layout_fp64)
-    ttgl.convert_layout(x_int32, mfma_layout_int32)
+    ttgl.full([128, 32], 0, ttgl.int32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.int32, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1]))
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])

@@ -1446,21 +1442,22 @@ def test_amd_mfma_layout(target):
     module = run_parser(amd_mfma_layout_kernel, target=target)
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
-#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
+#mma3 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @amd_mfma_layout_kernel() attributes {noinline = false} {
     %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
-    %cst_1 = arith.constant 0.000000e+00 : f64
-    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #blocked>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_1 = arith.constant 0.000000e+00 : f32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma1>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_5 = arith.constant 0.000000e+00 : f64
+    %cst_6 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #mma2>
     %c0_i32 = arith.constant 0 : i32
-    %cst_3 = arith.constant dense<0> : tensor<128x32xi32, #blocked>
-    %0 = ttg.convert_layout %cst_0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #mma>
-    %1 = ttg.convert_layout %cst_2 : tensor<128x32xf64, #blocked> -> tensor<128x32xf64, #mma1>
-    %2 = ttg.convert_layout %cst_3 : tensor<128x32xi32, #blocked> -> tensor<128x32xi32, #mma2>
+    %cst_7 = arith.constant dense<0> : tensor<128x32xi32, #mma3>
     tt.return
   }
 }

@@ -1475,8 +1472,8 @@ def add_int(a, b):
 @gluon.jit
 def infer_layout_for_amd_mfma_kernel():
     layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                       elem_type=ttgl.int32, warps_per_cta=[4,
-                                                                                            1], tiles_per_warp=[4, 1],
+                                                       warps_per_cta=[4,
+                                                                      1], elem_type=ttgl.int32, tiles_per_warp=[1, 1],
                                                        ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
     a = ttgl.full([128, 32], 1, ttgl.int32, layout)
     b = ttgl.reduce(a, 1, add_int)

@@ -1489,7 +1486,7 @@ def test_infer_layout_for_amd_mfma(target):
 
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @infer_layout_for_amd_mfma_kernel() attributes {noinline = false} {
     %c1_i32 = arith.constant 1 : i32

@@ -1719,3 +1716,49 @@ def test_buffer_load_store_with_broadcast(target):
   }
 }
 """)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_amd_mfma(target):
+
+    @gluon.jit
+    def kernel():
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
+                                                             warps_per_cta=[4, 1])
+
+        a = ttgl.full([64, 32], 1.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout,
+                                                                                k_width=8))
+        b = ttgl.full([32, 64], 2.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout,
+                                                                                k_width=8))
+
+        acc = ttgl.zeros([64, 64], ttgl.float32, mfma_layout)
+        acc = ttgl.amd.cdna3.mfma(a, b, acc)
+        ttgl.static_assert(isinstance(acc, ttgl.tensor))
+        ttgl.static_assert(acc.type.layout == mfma_layout)
+
+    module = run_parser(kernel, target=target)
+
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_1 = arith.constant 2.000000e+00 : f32
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    %0 = tt.call @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() : () -> tensor<64x64xf32, #mma>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %1 = tt.dot %cst_0, %cst_2, %0 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+  tt.func private @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() -> tensor<64x64xf32, #mma> attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    tt.return %cst_0 : tensor<64x64xf32, #mma>
+  ^bb1:  // no predecessors
+    %0 = ub.poison : tensor<64x64xf32, #mma>
+    tt.return %0 : tensor<64x64xf32, #mma>
+  }
+}
+""")
python/test/microbenchmark/launch_overhead.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+"""
+Original code by @bertmaher; profiling added by @apgoucher
+"""
+
+import cProfile
+import pstats
+import time
+
+import numpy as np
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def nop_args(
+    t1,
+    t2,
+    t3,
+    t4,
+    t5,
+    i1,
+    i2,
+    i3,
+    i4,
+    i5,
+    i6,
+    i7,
+    i8,
+    i9,
+    c1: tl.constexpr,
+    c2: tl.constexpr,
+    c3: tl.constexpr,
+    c4: tl.constexpr,
+    c5: tl.constexpr,
+):
+    pass
+
+
+def do_bench_walltime(fn):
+    print("Compiling...")
+    fn()
+    torch.cuda.synchronize()
+
+    for _ in range(1000):
+        fn()
+    torch.cuda.synchronize()
+
+    n_repeat = 10000
+
+    mses = []
+
+    for _ in range(25):
+        print("Running %d benchmarking iterations..." % n_repeat)
+        # Benchmark
+        torch.cuda.synchronize()
+        start_time = time.time()
+        for _ in range(n_repeat):
+            fn()
+        torch.cuda.synchronize()
+        end_time = time.time()
+        wall_time_ms = (end_time - start_time) * 1e3 / n_repeat
+        mses.append(wall_time_ms)
+
+    mses = np.array(mses)
+
+    print("Running profiler...")
+    profile = cProfile.Profile()
+    profile.enable()
+    for _ in range(n_repeat):
+        fn()
+    torch.cuda.synchronize()
+    profile.disable()
+    stats = pstats.Stats(profile)
+    stats.sort_stats("time")
+    stats.print_stats()
+    return mses
+
+
+def main():
+    targs = [torch.zeros(1, device="cuda") for _ in range(5)]
+    iargs = [1 for _ in range(9)]
+    cargs = [32 for _ in range(5)]
+
+    usecs = do_bench_walltime(lambda: nop_args[
+        1,
+    ](*targs, *iargs, *cargs)) * 1000.0
+
+    print(usecs)
+    print(sorted(usecs)[len(usecs) >> 1])
+
+
+if __name__ == "__main__":
+    main()
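
Editor's note: do_bench_walltime is generic over the callable it times, so the same helper can be pointed at other launches to track their wall-clock overhead. A hedged sketch of such reuse; the add_one kernel below is a made-up placeholder, and the helper is assumed to be imported or copied alongside it (this snippet is not part of the commit):

import torch
import triton
import triton.language as tl


@triton.jit
def add_one(x_ptr, n, BLOCK: tl.constexpr):
    # Trivial kernel used only to exercise the launch path.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)


x = torch.zeros(1024, device="cuda")
# Reuse the walltime helper above to measure this launch's per-call time in ms.
ms = do_bench_walltime(lambda: add_one[(1, )](x, x.numel(), BLOCK=1024))
print(ms.mean())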

python/test/unit/runtime/test_driver.py

Lines changed: 4 additions & 4 deletions
@@ -10,11 +10,11 @@ def test_is_lazy():
     from importlib import reload
     reload(sys.modules["triton.runtime.driver"])
     reload(sys.modules["triton.runtime"])
-    mod = sys.modules[triton.runtime.driver.__module__]
-    assert isinstance(triton.runtime.driver.active, getattr(mod, "LazyProxy"))
-    assert triton.runtime.driver.active._obj is None
+    assert triton.runtime.driver._active is None
+    assert triton.runtime.driver._default is None
+    assert isinstance(triton.runtime.driver.active, getattr(triton.backends.driver, "DriverBase"))
+    assert isinstance(triton.runtime.driver.default, getattr(triton.backends.driver, "DriverBase"))
     utils = triton.runtime.driver.active.utils  # noqa: F841
-    assert issubclass(triton.runtime.driver.active._obj.__class__, getattr(triton.backends.driver, "DriverBase"))
 
 
 def test_kernel_in_thread(device):
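
Editor's note: per the updated assertions, triton.runtime.driver.active and triton.runtime.driver.default now resolve lazily (the private _active/_default slots start as None) and return concrete DriverBase instances rather than LazyProxy wrappers. A minimal usage sketch, assuming a working GPU backend is installed:

import triton
from triton.backends.driver import DriverBase

drv = triton.runtime.driver.active   # first access resolves the backend driver
assert isinstance(drv, DriverBase)
utils = drv.utils                    # backend utilities, as exercised in the test above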
