Skip to content

Commit 588c3b9

Browse files
authored
Fix padding in SwiGLU operator; fix call to softmax in llama (#50)
1 parent 140ffa6 commit 588c3b9

File tree

8 files changed

+137
-77
lines changed

8 files changed

+137
-77
lines changed

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __init__(
8484
self.aie_softmax = AIESoftmax(
8585
num_aie_columns=1,
8686
num_channels=1,
87-
size=prompt_length * prompt_length,
88-
last_dim=prompt_length,
87+
rows=prompt_length,
88+
cols=prompt_length,
8989
)
9090
M_for_gemm = prompt_length + num_tokens
9191
self.aie_mha_gemm_qk = AIEGEMM(

applications/llama_3.2_1b/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
def generate_test_params():
14-
prompt_lengths = [2048]
15-
num_tokens_list = [40]
14+
prompt_lengths = [2048, 13]
15+
num_tokens_list = [40, 1]
1616

1717
params = []
1818
names = []

operators/common/aie_base.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -167,23 +167,46 @@ def _move_artifact_paths(self):
167167
todo.extend(artifact.depends)
168168

169169
def run_runlist(self):
170-
bos = set(
171-
self.buffer_bos[buffer_arg]
172-
for _, *buffer_args in self.runlist
173-
for buffer_arg in buffer_args
174-
)
175-
insts_bos = set(
176-
self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist
177-
)
178-
for bo in bos | insts_bos:
179-
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
180-
start = time.perf_counter()
181-
self.xrt_runlist.execute()
182-
self.xrt_runlist.wait()
183-
stop = time.perf_counter()
184-
for bo in bos:
185-
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
186-
return stop - start
170+
elapsed = 0.0
171+
if self.xrt_runlist is None:
172+
# Execute as separate xclbin kernel invocations
173+
for i, (kernel_name, *buffer_args) in enumerate(self.runlist):
174+
context, xrt_kernel, insts_bo, insts_len = self.xrt_kernels[kernel_name]
175+
insts_bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
176+
bos = [self.buffer_bos[buffer_arg] for buffer_arg in buffer_args]
177+
for bo in bos:
178+
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
179+
opcode = 3
180+
start = time.perf_counter()
181+
run = xrt_kernel(opcode, insts_bo, insts_len, *bos)
182+
result = run.wait()
183+
stop = time.perf_counter()
184+
elapsed += stop - start
185+
if result != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED:
186+
raise RuntimeError(
187+
f"Kernel {kernel_name} did not complete correctly: {result}"
188+
)
189+
for bo in bos:
190+
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
191+
else:
192+
bos = set(
193+
self.buffer_bos[buffer_arg]
194+
for _, *buffer_args in self.runlist
195+
for buffer_arg in buffer_args
196+
)
197+
insts_bos = set(
198+
self.xrt_kernels[kernel_name][2] for (kernel_name, *_) in self.runlist
199+
)
200+
for bo in bos | insts_bos:
201+
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE)
202+
start = time.perf_counter()
203+
self.xrt_runlist.execute()
204+
self.xrt_runlist.wait()
205+
stop = time.perf_counter()
206+
for bo in bos:
207+
bo.sync(pyxrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
208+
elapsed = stop - start
209+
return elapsed
187210

188211

189212
class AIEOperatorConstraintError(RuntimeError):

operators/common/aie_context.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,16 @@
1414
class AIEContext:
1515
"""Context for managing AIE operator compilation and runtime state"""
1616

17-
def __init__(self):
17+
def __init__(self, use_runlist=True):
1818
self.operators = []
1919
self.static_data_pool = {}
2020
self.device_manager = AIEDeviceManager()
2121
self.base_dir = Path(__file__).parent.parent.parent
2222
self.build_dir = Path(os.getcwd()) / "build"
2323
self.mlir_aie_dir = Path(aie.utils.config.root_path())
2424
self.peano_dir = Path(aie.utils.config.peano_install_dir())
25+
# Disabling the XRT runlist sacrifices performance by executing kernels individually as separate xclbin invocations, but makes debugging easier (you can tell which part of runlist execution failed)
26+
self.use_runlist = use_runlist
2527
self._runtime_prepared = False
2628

2729
def register_operator(self, operator):
@@ -146,20 +148,23 @@ def prepare_runtime(self):
146148
context, _ = self.device_manager.get_context_and_kernel(
147149
str(first_xclbin.path), first_xclbin_kernel_name
148150
)
149-
op.xrt_runlist = pyxrt.runlist(context)
150-
for i, (kernel_name, *buffer_args) in enumerate(op.runlist):
151-
this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[
152-
kernel_name
153-
]
154-
assert this_context == context
155-
opcode = 3
156-
run = pyxrt.run(xrt_kernel)
157-
run.set_arg(0, opcode)
158-
run.set_arg(1, insts_bo)
159-
run.set_arg(2, insts_len)
160-
for j, buffer_arg in enumerate(buffer_args):
161-
run.set_arg(j + 3, op.buffer_bos[buffer_arg])
162-
op.xrt_runlist.add(run)
151+
if self.use_runlist:
152+
op.xrt_runlist = pyxrt.runlist(context)
153+
for i, (kernel_name, *buffer_args) in enumerate(op.runlist):
154+
this_context, xrt_kernel, insts_bo, insts_len = op.xrt_kernels[
155+
kernel_name
156+
]
157+
assert this_context == context
158+
opcode = 3
159+
run = pyxrt.run(xrt_kernel)
160+
run.set_arg(0, opcode)
161+
run.set_arg(1, insts_bo)
162+
run.set_arg(2, insts_len)
163+
for j, buffer_arg in enumerate(buffer_args):
164+
run.set_arg(j + 3, op.buffer_bos[buffer_arg])
165+
op.xrt_runlist.add(run)
166+
else:
167+
op.xrt_runlist = None
163168

164169
# Log allocation info
165170
bo_count = sum(len(pool) for pool in bo_pools.values())

operators/softmax/op.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,9 @@
2121
class AIESoftmax(AIEOperatorBase):
2222

2323
def __init__(
24-
self,
25-
rows: int,
26-
cols: int,
27-
num_aie_columns=1,
28-
num_channels=1,
29-
tile_size=None,
30-
context=None,
24+
self, rows: int, cols: int, num_aie_columns=1, num_channels=1, context=None
3125
):
3226
self.size = rows * cols
33-
self.tile_size = tile_size if tile_size is not None else cols
3427
self.rows = rows
3528
self.cols = cols
3629

@@ -46,19 +39,19 @@ def __init__(
4639
def set_up_artifacts(self):
4740
# Compilation artifacts
4841
operator_dir = Path(__file__).parent
49-
file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t"
42+
file_name_base = f"softmax_{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.cols}t"
5043

5144
mlir_artifact = PythonGeneratedMLIRArtifact.new(
5245
f"{file_name_base}.mlir",
5346
import_path=operator_dir / "design.py",
5447
callback_fn="softmax",
5548
callback_args=[
5649
self.context.device_manager.device_type,
57-
self.size,
50+
self.rows * self.cols,
5851
self.num_columns,
5952
self.num_channels,
6053
0,
61-
self.tile_size,
54+
self.cols,
6255
],
6356
)
6457

@@ -105,7 +98,7 @@ def set_up_runtime(self):
10598
def forward(self, x):
10699
applicable = (
107100
x.shape[-1] * x.shape[-2] == self.size
108-
and x.shape[-1] == self.tile_size
101+
and x.shape[-1] == self.cols
109102
and x.shape[-1] % 16 == 0
110103
and x.shape[-2] % 16 == 0
111104
)

operators/softmax/test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def test_softmax(input_length, num_aie_columns, num_channels, tile_size, aie_con
8484
cols=cols,
8585
num_aie_columns=num_aie_columns,
8686
num_channels=num_channels,
87-
tile_size=tile_size,
8887
context=aie_context,
8988
)
9089

operators/swiglu_decode/op.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None)
4646
super().__init__(context=context)
4747

4848
def set_up_artifacts(self):
49-
# Artifact setup
50-
# ---
5149
artifacts = []
5250
device_str = self.context.device_manager.device_str()
5351

@@ -57,6 +55,7 @@ def set_up_artifacts(self):
5755
num_aie_columns=8,
5856
tile_size=1,
5957
)
58+
self.gemv_1 = gemv_1
6059
gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts(
6160
prefix="swiglu_decode_gemv_1_"
6261
)
@@ -75,6 +74,8 @@ def set_up_artifacts(self):
7574
num_channels=2,
7675
tile_size=self.hidden_dim // 16,
7776
)
77+
self.silu = silu
78+
self.hidden_dim_padded = silu.size
7879
silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_")
7980
silu_xclbin.xclbin_input = gemv_1_xclbin
8081
silu_xclbin.extra_flags += [
@@ -91,6 +92,8 @@ def set_up_artifacts(self):
9192
num_channels=2,
9293
tile_size=self.hidden_dim // 8,
9394
)
95+
self.eltwise_mul = eltwise_mul
96+
assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded
9497
eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts(
9598
prefix="swiglu_decode_eltwise_mul_"
9699
)
@@ -109,6 +112,7 @@ def set_up_artifacts(self):
109112
num_aie_columns=8,
110113
tile_size=1,
111114
)
115+
self.gemv_2 = gemv_2
112116
gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts(
113117
prefix="swiglu_decode_gemv_2_"
114118
)
@@ -135,28 +139,26 @@ def set_up_artifacts(self):
135139
self.add_artifacts(artifacts)
136140

137141
def set_up_runtime(self):
138-
# Runtime setup
139-
# ---
140142
self.add_buffer("input", self.embedding_dim)
141143
self.add_buffer(
142144
"weights_1",
143-
self.embedding_dim * self.hidden_dim,
145+
self.embedding_dim * self.hidden_dim_padded,
144146
static_data=torch_to_numpy(self.weights_1),
145147
)
146148
self.add_buffer(
147149
"weights_2",
148-
self.embedding_dim * self.hidden_dim,
150+
self.embedding_dim * self.hidden_dim_padded,
149151
static_data=torch_to_numpy(self.weights_2),
150152
)
151153
self.add_buffer(
152154
"weights_3",
153-
self.hidden_dim * self.embedding_dim,
155+
self.hidden_dim_padded * self.embedding_dim,
154156
static_data=torch_to_numpy(self.weights_3),
155157
)
156-
self.add_buffer("left", self.hidden_dim)
157-
self.add_buffer("left_swished", self.hidden_dim)
158-
self.add_buffer("right", self.hidden_dim)
159-
self.add_buffer("intermediate", self.hidden_dim)
158+
self.add_buffer("left", self.hidden_dim_padded)
159+
self.add_buffer("left_swished", self.hidden_dim_padded)
160+
self.add_buffer("right", self.hidden_dim_padded)
161+
self.add_buffer("intermediate", self.hidden_dim_padded)
160162
self.add_buffer("output", self.embedding_dim)
161163
self.add_kernel(
162164
"swiglu_gemv_1",
@@ -191,9 +193,7 @@ def set_up_runtime(self):
191193
self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output")
192194

193195
def forward(self, x):
194-
# Turn into a numpy vector and drop the batch and other higher dimensions, if any; will error if batch or other higher dimensions > 1
195196
x_flat = x.reshape(x.shape[-1])
196-
197197
assert x_flat.shape[0] == self.embedding_dim
198198

199199
self.write_buffer("input", x_flat)

0 commit comments

Comments
 (0)