Skip to content

Commit 328fd8a

Browse files
Merge commit 'cda4229558c5dca7f7c4734bedd3e596ebcae0b8'
2 parents 09ebc1a + cda4229 commit 328fd8a

File tree

7 files changed: 40 additions and 17 deletions

7 files changed

+40
-17
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -109,3 +109,20 @@ def test_list_of_functions():
109109
# CHECK-NEXT: call @anchor
110110
# CHECK-NEXT: call @forward
111111
list_of_functions_constexpr(tl.arange(0, 4), [anchor, forward])
112+
113+
114+
@triton.jit
115+
def accumulate(a, b):
116+
return a + b
117+
118+
119+
# Check that we can call a function returning a value from a loop.
120+
@filecheck_test
121+
@triton.jit
122+
def test_call_in_loop():
123+
# CHECK-LABEL: test_call_in_loop
124+
acc = 0
125+
# CHECK: scf.for
126+
# CHECK: call @accumulate
127+
for i in range(10):
128+
acc = accumulate(acc, i)

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -135,10 +135,9 @@ def _visit_stmts(self, body) -> bool:
135135
return any(self.visit(s) for s in body)
136136

137137
def _visit_function(self, fn) -> bool:
138-
# Currently we only support JITFunctions defined in the global scope
139-
if isinstance(fn, JITFunction) and not fn.noinline:
140-
fn_node = fn.parse()
141-
return ContainsReturnChecker(self.gscope).visit(fn_node)
138+
# no need to check within the function as it won't cause an early return.
139+
# If the function itself has unstructured control flow we may not be able to inline it causing poor performance.
140+
# We should check for this and fail or emit a warning.
142141
return False
143142

144143
def generic_visit(self, node) -> bool:

python/triton_kernels/tests/test_routing.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@ def ref_expt_data(routing_data, n_gates, block_m):
4444
@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)])
4545
@pytest.mark.parametrize("block_m", [64, 128])
4646
@pytest.mark.parametrize("use_expt_indx", [False, True])
47-
def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, use_expt_indx, device):
47+
@pytest.mark.parametrize("renormalize", [True, False])
48+
def test_op(n_tokens, n_expts_tot, n_expts_act, renormalize, block_m, use_expt_indx, device):
4849
torch.manual_seed(2)
4950
tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach()
5051
ref_logits = tri_logits.clone()
@@ -55,8 +56,11 @@ def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, use_expt_indx, device):
5556
ref_expt_indx = tri_expt_indx[:n_tokens]
5657
else:
5758
tri_expt_indx = ref_expt_indx = None
58-
ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, ref_expt_indx)
59-
tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, tri_expt_indx)
59+
if not renormalize:
60+
tri_logits = torch.softmax(tri_logits, dim=-1)
61+
ref_logits = torch.softmax(ref_logits, dim=-1)
62+
ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, renormalize, ref_expt_indx)
63+
tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, renormalize, tri_expt_indx)
6064
ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
6165
tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
6266

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -514,7 +514,6 @@ def _p_matmul_ogs(
514514
if SWAP_XW:
515515
acc_tile = acc_tile.T
516516
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
517-
acc_tile *= gammas[:, None]
518517
if out_alpha is not None:
519518
acc_tile *= out_alpha
520519

@@ -525,6 +524,8 @@ def _p_matmul_ogs(
525524
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
526525
out = acc_tile
527526

527+
out *= gammas[:, None]
528+
528529
if MASK_ACC:
529530
out = tl.where(mask_m[:, None], out, 0.0)
530531
# Flexpoint

python/triton_kernels/triton_kernels/routing.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def n_blocks(self, n_rows, block_m):
5353
# --------------------------
5454

5555

56-
def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
56+
def routing(logits, n_expts_act, renormalize=True, expt_indx=None, simulated_ep=1):
5757
from .topk import topk
5858
from .compaction import compaction
5959
cdiv = triton.cdiv
@@ -63,7 +63,7 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
6363
n_tokens, n_expts_tot = logits.shape
6464
n_gates = n_tokens * n_expts_act
6565
device = logits.device
66-
expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, y_indx=expt_indx)
66+
expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, apply_softmax=renormalize, y_indx=expt_indx)
6767
# mutate bitmatrix
6868
if simulated_ep > 1:
6969
assert n_expts_tot % simulated_ep == 0
@@ -108,7 +108,7 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
108108
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx
109109

110110

111-
def routing_torch(logits, n_expts_act, expt_indx=None):
111+
def routing_torch(logits, n_expts_act, renormalize=True, expt_indx=None):
112112

113113
def topk(vals, k, expt_indx):
114114
# topk of experts
@@ -121,7 +121,8 @@ def topk(vals, k, expt_indx):
121121

122122
_, n_expts_tot = logits.shape
123123
expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx)
124-
expt_scal = torch.softmax(expt_scal, dim=-1)
124+
if renormalize:
125+
expt_scal = torch.softmax(expt_scal, dim=-1)
125126
# flatten topk data
126127
expt_scal = expt_scal.reshape(-1)
127128
expt_indx = expt_indx.reshape(-1).to(torch.int32)

python/triton_kernels/triton_kernels/topk.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from .bitmatrix import Bitmatrix
44

55

6-
def topk(x, k, dim=1, return_bitmatrix=True, y_indx=None):
6+
def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None):
77
cdiv = lambda a, b: (a + b - 1) // b
88
BLOCK_M = 32
99
BLOCK_N = 32
@@ -39,5 +39,5 @@ def topk(x, k, dim=1, return_bitmatrix=True, y_indx=None):
3939
S, BLOCK_S, s_blocks, # thing to memset to zero
4040
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter
4141
N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants
42-
)
42+
APPLY_SOFTMAX=apply_softmax)
4343
return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols], S)

python/triton_kernels/triton_kernels/topk_details/_topk.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,8 @@ def _topk(X, stride_xm, # inputs
7272
Yv, Yi, stride_ym, # topk values/indices
7373
USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, n_rows, # bitmatrix
7474
n_expts_tot, S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset
75-
BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):
75+
BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr,
76+
APPLY_SOFTMAX: tl.constexpr):
7677

7778
pid = tl.program_id(0)
7879

@@ -105,8 +106,8 @@ def _topk(X, stride_xm, # inputs
105106
y_indices = y & 0x0000FFFF
106107
y_values = (y >> x_nbits).to(x_utype).to(x_dtype, bitcast=True)
107108

108-
# normalize selected values
109-
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
109+
if APPLY_SOFTMAX:
110+
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
110111

111112
# write back
112113
Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]

0 commit comments

Comments
 (0)