
Commit ca78cd2

Fix fp16 var_mean multi-output issue (#357)
Fixes #344
1 parent: d49641e

File tree: 4 files changed, +197 -18 lines


helion/_compiler/inductor_lowering.py: 17 additions, 4 deletions
@@ -245,10 +245,16 @@ def convert_arg(arg: Node) -> TensorBox:
     if len(result) > 1 and nodes:
         last_node = nodes[-1]  # The last node is the main node
         output_nodes = {}
+        extra_deps = []
         for n in nodes:
             if "output_index" in n.meta:
                 output_nodes[n.meta["output_index"]] = n.name
+                if n is not last_node and n not in last_node._input_nodes:
+                    extra_deps.append(n)
         last_node.meta["output_nodes"] = output_nodes
+        if extra_deps:
+            # Need to ensure that the last node depends on all output nodes to prevent DCE issues
+            last_node.kwargs = {**last_node.kwargs, "_extra_deps": extra_deps}
 
 
 def strip_unused_inputs(
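Why `_extra_deps` is needed: dead-code elimination removes graph nodes with no users, and for a multi-output op like var_mean any earlier output node the main node does not already read is exactly such a node. A minimal, standalone illustration using plain torch.fx (ours, not Helion's internal DCE pass):

import torch
import torch.fx

def f(x):
    out = torch.var_mean(x, dim=-1)  # one call_function node with two outputs
    var, mean = out[0], out[1]       # two getitem nodes
    return var                       # `mean` is never consumed

gm = torch.fx.symbolic_trace(f)
n_before = len(gm.graph.nodes)
gm.graph.eliminate_dead_code()       # drops the unused getitem for `mean`
assert len(gm.graph.nodes) == n_before - 1

Threading the surviving output nodes through the last node's kwargs gives each of them a user, so they cannot be eliminated.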
@@ -371,7 +377,8 @@ def visit(n: torch.fx.Node) -> None:
         device_function: DeviceFunction = ctx.cg.device_function
         ndim: int = max([x.ndim for x in self.input_fake_tensors(node)] or (0,))
         input_asts: list[ast.AST] = []
-        map_arg((node.args, node.kwargs), visit)
+        # _extra_deps should not be included in the inductor node inputs
+        map_arg((node.args, {**node.kwargs, "_extra_deps": None}), visit)
         assert len(input_asts) == len(self.input_names)
         return input_asts
 
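The override works because map_arg only visits Node instances inside a nested structure: shadowing the value under "_extra_deps" with None hides those helper edges from input collection without mutating node.kwargs. A quick check (ours):

import torch
from torch.fx.node import map_arg

g = torch.fx.Graph()
a, b = g.placeholder("a"), g.placeholder("b")
seen = []
map_arg(((a,), {"k": b, "_extra_deps": None}), seen.append)
assert seen == [a, b]  # the None entry contributes no inputs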
@@ -411,9 +418,7 @@ def install_inductor_kernel_handlers(
                 "split_reductions": False,
             }
         ),
-        V.set_graph_handler(
-            GraphLowering(dummy_gm(), shape_env=CompileEnvironment.current().shape_env)
-        ),
+        V.set_graph_handler(FakeGraphLowering()),
         V.set_ops_handler(
             GenerateASTFromInductor(
                 cg,
@@ -432,6 +437,14 @@ def dummy_gm() -> torch.fx.GraphModule:
     return torch.fx.symbolic_trace(lambda: None)
 
 
+class FakeGraphLowering(GraphLowering):
+    def __init__(self) -> None:
+        env = CompileEnvironment.current()
+        super().__init__(dummy_gm(), shape_env=env.shape_env)
+        # Set the device directly on the graph_lowering to ensure get_current_device_or_throw() works
+        self._current_device = env.device
+
+
 class PointwiseLowering(InductorLowering):
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         with self.install_kernel_handlers(ctx, node):
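A sketch (ours, not part of the commit) of what the new handler guarantees; it only makes sense inside an active Helion compile, where CompileEnvironment.current() is populated:

from torch._inductor.virtualized import V

# Hypothetical usage inside Helion's lowering pipeline: with the fake graph
# handler installed, inductor code that asks for the active device gets the
# environment's device back instead of raising.
with V.set_graph_handler(FakeGraphLowering()):
    device = V.graph.get_current_device_or_throw()  # == CompileEnvironment.current().device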

helion/_compiler/reduction_strategy.py: 16 additions, 14 deletions
@@ -11,6 +11,7 @@
 from torch._inductor.ir import get_reduction_combine_fn
 from torch._inductor.runtime.runtime_utils import next_power_of_2
 from torch._inductor.utils import triton_type
+from torch._prims_common import get_computation_dtype
 
 from ..autotuner.config_fragment import integer_power_of_two
 from .ast_extension import create
@@ -292,22 +293,23 @@ def codegen_reduction(
         fake_input: torch.Tensor,
         fake_output: torch.Tensor,
     ) -> ast.AST:
-        device_loop = state.codegen.active_device_loops[self.block_index][-1]
-        assert isinstance(device_loop, DeviceLoopState)
-        shape = self.fn.tile_strategy.shape_str([*fake_input.size()])
-        default = ir.Reduction.default_accumulator(reduction_type, fake_input.dtype)
-        assert isinstance(default, (float, int, bool))
-        assert state.fx_node is not None
-        acc = self.fn.new_var(f"{state.fx_node.name}_acc", dce=True)
-        device_loop.outer_prefix.append(
-            statement_from_string(
-                f"{acc} = tl.full({shape}, {constant_repr(default)}, {triton_acc_type(fake_input.dtype)})"
-            )
-        )
-        result = self.fn.new_var(state.fx_node.name, dce=True)
         with install_inductor_kernel_handlers(state.codegen, {}):
+            device_loop = state.codegen.active_device_loops[self.block_index][-1]
+            assert isinstance(device_loop, DeviceLoopState)
+            shape = self.fn.tile_strategy.shape_str([*fake_input.size()])
+            acc_dtype = get_computation_dtype(fake_input.dtype)  # promote fp16 to fp32
+            default = ir.Reduction.default_accumulator(reduction_type, acc_dtype)
+            assert isinstance(default, (float, int, bool))
+            assert state.fx_node is not None
+            acc = self.fn.new_var(f"{state.fx_node.name}_acc", dce=True)
+            device_loop.outer_prefix.append(
+                statement_from_string(
+                    f"{acc} = tl.full({shape}, {constant_repr(default)}, {triton_acc_type(acc_dtype)})"
+                )
+            )
+            result = self.fn.new_var(state.fx_node.name, dce=True)
             if reduction_type not in {"argmin", "argmax"}:
-                combine_fn = get_reduction_combine_fn(reduction_type, fake_input.dtype)
+                combine_fn = get_reduction_combine_fn(reduction_type, acc_dtype)
                 state.add_statement(f"{acc} = {combine_fn(acc, input_name)}")
             expr = self.call_reduction_function(
                 acc, reduction_type, dim, fake_input, fake_output
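The fix hinges on get_computation_dtype promoting reduced-precision floats before the accumulator is built. torch._prims_common is a private API, so treat this as an illustrative check rather than a contract:

import torch
from torch._prims_common import get_computation_dtype

assert get_computation_dtype(torch.float16) == torch.float32   # the fp16 case from #344
assert get_computation_dtype(torch.bfloat16) == torch.float32
assert get_computation_dtype(torch.float32) == torch.float32   # full precision unchanged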

test/test_reductions.expected: 122 additions, 0 deletions
@@ -62,6 +62,128 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
     _launcher(_reduce_kernel_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    x_part = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
+    v_0 = x_part.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = 64
+    v_2 = var_mean_extra / v_1.to(tl.float32)
+    v_3 = x_part.to(tl.float32)
+    v_4 = v_3 - v_2
+    v_5 = v_4 * v_4
+    var_mean_extra_2 = tl.reshape(tl.sum(v_5, 1), [_BLOCK_SIZE_0, 1])
+    v_6 = 64
+    v_7 = var_mean_extra_2 / v_6.to(tl.float32)
+    v_8 = v_7.to(tl.float16)
+    v_9 = v_2.to(tl.float16)
+    v_10 = x_part - v_9
+    v_11 = v_8.to(tl.float32)
+    v_12 = v_11 + eps
+    v_13 = libdevice.rsqrt(v_12)
+    v_14 = v_10.to(tl.float32)
+    v_15 = v_14 * v_13
+    load_1 = tl.load(weight + indices_1 * 1, None)
+    v_16 = load_1.to(tl.float32)
+    v_17 = v_16[None, :]
+    v_18 = v_15 * v_17
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_19 = load_2.to(tl.float32)
+    v_20 = v_19[None, :]
+    v_21 = v_18 + v_20
+    v_22 = v_21.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_22, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestReductions.test_fp16_var_mean)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_repro_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    var_mean_extra_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        x_part = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_0 = x_part.to(tl.float32)
+        v_1 = var_mean_extra_acc + v_0
+        var_mean_extra_acc = v_1
+    var_mean_extra = tl.reshape(tl.sum(var_mean_extra_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_2 = 64
+    v_3 = var_mean_extra / v_2.to(tl.float32)
+    var_mean_extra_2_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_3_copy = v_3
+        x_part_1 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_4 = x_part_1.to(tl.float32)
+        v_5 = v_4 - v_3_copy
+        v_6 = v_5 * v_5
+        v_7 = var_mean_extra_2_acc + v_6
+        var_mean_extra_2_acc = v_7
+    var_mean_extra_2 = tl.reshape(tl.sum(var_mean_extra_2_acc, 1), [_BLOCK_SIZE_0, 1])
+    v_8 = 64
+    v_9 = var_mean_extra_2 / v_8.to(tl.float32)
+    v_10 = v_9.to(tl.float16)
+    v_11 = v_3.to(tl.float16)
+    v_12 = v_10.to(tl.float32)
+    v_13 = v_12 + eps
+    v_14 = libdevice.rsqrt(v_13)
+    for roffset_1 in tl.range(0, 64, _REDUCTION_BLOCK_1):
+        rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
+        v_11_copy = v_11
+        v_14_copy = v_14
+        x_part_2 = tl.load(x + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), None)
+        v_15 = x_part_2 - v_11_copy
+        v_16 = v_15.to(tl.float32)
+        v_17 = v_16 * v_14_copy
+        load_1 = tl.load(weight + rindex_1 * 1, None)
+        v_18 = load_1.to(tl.float32)
+        v_19 = v_18[None, :]
+        v_20 = v_17 * v_19
+        load_2 = tl.load(bias + rindex_1 * 1, None)
+        v_21 = load_2.to(tl.float32)
+        v_22 = v_21[None, :]
+        v_23 = v_20 + v_22
+        v_24 = v_23.to(tl.float16)
+        tl.store(out + (indices_0[:, None] * 64 + rindex_1[None, :] * 1), v_24, None)
+
+def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _REDUCTION_BLOCK_1 = 8
+    _launcher(_layer_norm_fwd_repro_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestReductions.test_mean)
 def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
     # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
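Both journals implement the same math as this eager-mode reference (our reconstruction, mirroring the Helion kernel in test_fp16_var_mean below; it matches the generated kernels only within fp16 tolerance):

import torch

def layer_norm_fwd_ref(x, weight, bias, eps=1e-5):
    # var/mean over the last dim, then a standard layer-norm epilogue in fp32
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, correction=0)
    normalized = (x - mean) * torch.rsqrt(var.to(torch.float32) + eps)
    return (normalized * weight.to(torch.float32) + bias.to(torch.float32)).to(torch.float16)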

test/test_reductions.py: 42 additions, 0 deletions
@@ -188,6 +188,48 @@ def layer_norm_reduction(
         )
         self.assertExpectedJournal(code)
 
+    def test_fp16_var_mean(self):
+        @helion.kernel(static_shapes=True)
+        def layer_norm_fwd_repro(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            eps: float = 1e-5,
+        ) -> torch.Tensor:
+            m, n = x.size()
+            out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+            for tile_m in hl.tile(m):
+                x_part = x[tile_m, :]
+                var, mean = torch.var_mean(x_part, dim=-1, keepdim=True, correction=0)
+                normalized = (x_part - mean) * torch.rsqrt(var.to(torch.float32) + eps)
+                out[tile_m, :] = normalized * (weight[:].to(torch.float32)) + (
+                    bias[:].to(torch.float32)
+                )
+            return out
+
+        batch_size = 32
+        dim = 64
+        x = torch.randn([batch_size, dim], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn([dim], device=DEVICE, dtype=torch.float16)
+        bias = torch.randn([dim], device=DEVICE, dtype=torch.float16)
+        eps = 1e-4
+        code1, result1 = code_and_output(
+            layer_norm_fwd_repro,
+            (x, weight, bias, eps),
+            block_sizes=[32],
+            reduction_loops=[None],
+        )
+        self.assertExpectedJournal(code1)
+
+        code2, result2 = code_and_output(
+            layer_norm_fwd_repro,
+            (x, weight, bias, eps),
+            block_sizes=[32],
+            reduction_loops=[8],
+        )
+        self.assertExpectedJournal(code2)
+        torch.testing.assert_close(result1, result2, rtol=1e-3, atol=1e-3)
+
 
 if __name__ == "__main__":
     unittest.main()
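Assuming the repository's standard test layout, the new test can be run on its own with either runner:

python -m unittest test.test_reductions.TestReductions.test_fp16_var_mean
python -m pytest test/test_reductions.py -k test_fp16_var_mean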
