Skip to content

Commit 9b27dff

Browse files
authored
[Frontend] Revert deep copy when entering new scope (triton-lang#8832)
This was added when we were attempting to support mutations in the frontend. However, now that we ban mutations it's nothing but an added cost during compilation. With this change I see a 250ms improvement in compilation time for the attention gluon example kernel.
1 parent 8e083aa commit 9b27dff

File tree

2 files changed

+2
-58
lines changed

2 files changed

+2
-58
lines changed

python/test/unit/language/test_frontend.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,6 @@ class TypeWithBuiltinInitializer:
152152
def __init__(self, _semantic=None):
153153
self.value = tl.arange(0, 4, _semantic=_semantic)
154154

155-
def modify(self, value, _semantic=None):
156-
self.value = value
157-
158155

159156
@filecheck_test
160157
@triton.jit
@@ -164,48 +161,6 @@ def test_aggregate_initializers():
164161
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
165162
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
166163
anchor(value)
167-
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
168-
# CHECK: call @{{.*}}anchor{{.*}}([[RANGE]])
169-
value.modify(tl.arange(4, 8))
170-
anchor(value)
171-
172-
173-
@filecheck_test
174-
@triton.jit
175-
def test_aggregate_modification_in_for_loop():
176-
# CHECK-LABEL: test_aggregate_modification_in_for_loop
177-
value = TypeWithBuiltinInitializer()
178-
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
179-
for i in range(0, 2):
180-
# CHECK: [[RET:%.*]] = scf.for
181-
# CHECK-SAME: iter_args([[ITER:%.*]] = [[RANGE]])
182-
value.modify(tl.arange(4, 8))
183-
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
184-
# CHECK: yield [[RANGE]]
185-
186-
anchor(value)
187-
# CHECK: call @{{.*}}anchor{{.*}}([[RET]])
188-
189-
190-
@filecheck_test
191-
@triton.jit
192-
def test_aggregate_modification_in_while_loop():
193-
# CHECK-LABEL: test_aggregate_modification_in_while_loop
194-
value = TypeWithBuiltinInitializer()
195-
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
196-
i = 0
197-
# CHECK: [[C0:%.*]] = arith.constant 0 :
198-
while i < 1:
199-
# CHECK: [[RET:%.*]]:2 = scf.while ([[ITER:%.*]] = [[RANGE]], [[IV:%.*]] = [[C0]])
200-
# CHECK: do
201-
i = 1
202-
# CHECK: [[C1:%.*]] = arith.constant 1 :
203-
value.modify(tl.arange(4, 8))
204-
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
205-
# CHECK: yield [[RANGE]], [[C1]]
206-
207-
anchor(value)
208-
# CHECK: call @{{.*}}anchor{{.*}}([[RET]]#0)
209164

210165

211166
@triton.jit

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,26 +110,15 @@ def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
110110
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
111111

112112

113-
def _clone_triton_value(val):
114-
handles = []
115-
val._flatten_ir(handles)
116-
clone, _ = val.type._unflatten_ir(handles, 0)
117-
return clone
118-
119-
120-
def _clone_scope(scope):
121-
return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
122-
123-
124113
class enter_sub_region:
125114

126115
def __init__(self, generator):
127116
self.generator = generator
128117

129118
def __enter__(self):
130119
# record lscope & local_defs in the parent scope
131-
self.liveins = _clone_scope(self.generator.lscope)
132-
self.prev_defs = _clone_scope(self.generator.local_defs)
120+
self.liveins = dict(self.generator.lscope)
121+
self.prev_defs = dict(self.generator.local_defs)
133122
self.generator.local_defs = {}
134123
self.insert_block = self.generator.builder.get_insertion_block()
135124
self.insert_point = self.generator.builder.get_insertion_point()

0 commit comments

Comments
 (0)