Commit 7b9a467

Add unplaced IRON design for conv2d 14x14 (#2601)

1 parent 0aa65a6 commit 7b9a467

File tree: 10 files changed, +312 −49 lines changed

aie_kernels/aie2p/conv2dk14.cc

Lines changed: 5 additions & 6 deletions

@@ -60,16 +60,15 @@ void conv2dk14_i8_scalar(uint8_t *input, int8_t *kernels, int8_t *output,
   int wts_indx = 0;
   int out_indx = 0;
 
-  const int output_channels_div_8 = output_channels / 8;
-  const int tiles_div_8 = input_width / kernel_width / 8;
-  const int pixels_div_2 = kernel_width * kernel_width / 2;
+  const int output_channels_div_8 = output_channels / 8;    // 2
+  const int tiles_div_8 = input_width / kernel_width / 8;   // 2
+  const int pixels_div_2 = kernel_width * kernel_width / 2; // 98
 
   for (oc = 0; oc < output_channels_div_8; oc++) { // 16 out of 1152
     for (oc8 = 0; oc8 < 8; oc8++) {
       for (nt = 0; nt < tiles_div_8; nt++) { // 16 out of 64 tiles in row
         for (nt8 = 0; nt8 < 8; nt8++) {
           int sum = 0;
-          int sum_srs = 0;
           for (pix = 0; pix < pixels_div_2; pix++) { // 196 // 2 = 98
             for (p2 = 0; p2 < 2; p2++) {
               in_indx = ((nt * (pixels_div_2) * 8 * 2) + (pix * 8 * 2) +
@@ -83,7 +82,7 @@ void conv2dk14_i8_scalar(uint8_t *input, int8_t *kernels, int8_t *output,
                          input[in_indx + 3] * kernels[wts_indx + 24];
             }
           }
-          sum_srs = (sum + (1 << (scale - 1))) >> scale;
+          int sum_srs = (sum + (1 << (scale - 1))) >> scale;
           sum_srs = (sum_srs > SMAX)    ? SMAX
                     : (sum_srs < -SMIN) ? -SMIN
                                         : sum_srs;
@@ -154,7 +153,7 @@ void conv2dk14_i8_vector(uint8_t *input, int8_t *kernels, int8_t *output,
   int8_t *__restrict out_ptr = output;
 
   for (int k = 0; k < output_channels_div_8; k++) { // 2
-    for (int j = 0; j < tiles_div_16; j++) { // 2
+    for (int j = 0; j < tiles_div_16; j++) { // 1
      AIE_PREPARE_FOR_PIPELINING
      AIE_LOOP_MIN_ITERATION_COUNT(98)
      // AIE_LOOP_UNROLL_FULL
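Note: the `sum_srs` lines this diff touches implement a shift-round-saturate (SRS): add half of `2^scale` so the right shift rounds to nearest, then clamp to the int8 range. A minimal Python sketch of that arithmetic; the `SMAX`/`SMIN` values are assumed here (they are defined elsewhere in the kernel source, not in this diff):

```python
# Shift-round-saturate as in the scalar kernel; SMAX/SMIN values are assumed.
SMAX, SMIN = 127, 128
scale = 14

def srs(acc: int) -> int:
    sum_srs = (acc + (1 << (scale - 1))) >> scale  # add 2^(scale-1) to round
    return max(-SMIN, min(SMAX, sum_srs))          # saturate to [-128, 127]

assert srs(8192) == 1         # exactly half of 2^14 rounds up to 1
assert srs(3_000_000) == 127  # large accumulators saturate
```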

programming_examples/ml/conv2d_14x14/Makefile

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ ifeq ($(CHESS), true)
 	cd ${@D} && aiecc.py --aie-generate-xclbin --aie-generate-npu-insts --no-compile-host \
 		--xclbin-name=${@F} --npu-insts-name=insts_trace.bin $(<:%=../%)
 else
-	cd ${@D} && aiecc.py --aie-generate-xclbin --aie-generate-npu-insts --no-compile-host \
+	cd ${@D} && aiecc.py -v --aie-generate-xclbin --aie-generate-npu-insts --no-compile-host --packet-sw-objFifos \
 		--no-xchesscc --no-xbridge --xclbin-name=${@F} --npu-insts-name=insts_trace.bin $(<:%=../%)
 endif

programming_examples/ml/conv2d_14x14/README.md

Lines changed: 14 additions & 5 deletions

@@ -65,15 +65,24 @@ To compile and run the design:
 make run_py
 ```
 
-
 To build and run the design while generating trace
 ```shell
-make trace_py
+make clean; make trace_py
+```
+
+To build and run the unplaced IRON version of the design (with or without generating trace), add the additional qualifiers:
+```shell
+make clean; make use_placed=0 num_act=72 run_py
+make clean; make use_placed=0 num_act=72 trace_py
 ```
 
 To build and run the 32-core design (trace not currently supported)
 ```shell
-make targetname=conv2dk14_32core num_act=1 run_py
+make clean; make targetname=conv2dk14_32core num_act=1 run_py
+```
+To build and run the 32-core design with the scalar kernel (trace not currently supported)
+```shell
+make clean; make vectorized=false targetname=conv2dk14_32core num_act=1 run_py
 ```
 
 ## Multi-core Design Example (32-cores)
@@ -86,8 +95,8 @@ While the design was designed to be somewhat configurable, this is mostly tested
 
 ## Limitation Notes
 At the moment, the following limitations exist:
-* The scalar kernel version of this design has an intermittent runtime issue (CMD_ABORT triggered) for the full output channel size. Reducing this to 256 channels from 1152 is a workaround for now, but further investigation is needed to fully resolve it.
-* The unplaced IRON version is in the works. At the moment, writing trace data to the 5th buffer, which is the default for unplaced IRON, seems to trigger a segfault. Further investigation is needed.
+* The scalar kernel version of this design does not run properly in single-core mode for the full data size because the total compute time exceeds the execution time limit of the NPU driver (~2 seconds). You can reduce the number of output channels (576 channels works) or run the scalar kernel with the 32-core design as noted above.
+* Unplaced IRON now works but needs additional qualifiers for the testbench. However, a trace_size of 32,768 bytes (rather than 16 kB or 8 kB) triggers a bug that causes the unplaced IRON trace to segfault. This is still under investigation, but choosing a smaller size is a good workaround.
 * Trace for the 32-core variant currently causes the compilation to hang. This is under investigation, but the non-trace run works without issue.
 * There is a behavior bug where the number of input/activation sets sent from the host to the AIE array needs to be a certain value for correct functionality. For the single-core design, `num_act=2` is sufficient for non-trace runs (`run_py`), but for trace runs (`trace_py`) it needs to be `num_act=8`. For the 32-core design, `num_act=1` is sufficient, but any value for trace runs currently causes a hang. This is under investigation.

programming_examples/ml/conv2d_14x14/conv2dk14.py (new file)

Lines changed: 230 additions & 0 deletions

@@ -0,0 +1,230 @@
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2024 AMD Inc.
+import numpy as np
+import sys
+
+from aie.iron import (
+    GlobalBuffer,
+    Kernel,
+    ObjectFifo,
+    Program,
+    Runtime,
+    Worker,
+    WorkerRuntimeBarrier,
+)
+from aie.iron.placers import SequentialPlacer
+from aie.iron.device import NPU1Col1, NPU2Col1
+from aie.iron.controlflow import range_
+
+
+def conv2dk14(
+    dev,
+    width: int,
+    height: int,
+    in_channels: int,
+    out_channels: int,
+    kernel_size: int,
+    trace_size: int,
+):
+    enable_trace = 1 if trace_size > 0 else 0
+
+    # Kernel processes 16 tiles and 16 output channels at a time
+    sub_out_channels = 16
+    sub_tiles = 16
+
+    actIn = kernel_size * kernel_size * in_channels * sub_tiles
+    weights = kernel_size * kernel_size * in_channels * sub_out_channels
+    actOut = sub_tiles * sub_out_channels
+
+    out_channels_group = out_channels // sub_out_channels  # 72
+    width_out = width // kernel_size
+    height_out = height // kernel_size
+
+    # we reload inputs 72 times (out_channels // sub_out_channels)
+    tensorInSize = width * height * in_channels * out_channels_group
+    # tensorInSize = width * height * in_channels * 2
+
+    tensorWeightsSize = weights * out_channels_group
+    tensorOutSize = width_out * height_out * sub_out_channels * out_channels_group
+
+    N_in_bytes = tensorOutSize  # Number of bytes of output data (1 byte/elem)
+
+    bufIn = kernel_size * width * in_channels
+    bufOut = sub_out_channels * width_out * height_out
+
+    # Type definitions
+    actIn_ty = np.ndarray[(actIn,), np.dtype[np.uint8]]
+    bufIn_ty = np.ndarray[(bufIn,), np.dtype[np.uint8]]
+
+    weights_ty = np.ndarray[(weights,), np.dtype[np.int8]]
+
+    out_ty = np.ndarray[(actOut,), np.dtype[np.int8]]
+    bufOut_ty = np.ndarray[(bufOut,), np.dtype[np.int8]]
+    tensorIn_ty = np.ndarray[(tensorInSize,), np.dtype[np.uint8]]
+    tensorWeights_ty = np.ndarray[(tensorWeightsSize,), np.dtype[np.int8]]
+    tensorOut_ty = np.ndarray[(tensorOutSize,), np.dtype[np.int8]]
+
+    # AIE Core Function declarations
+    conv2dk14_i8_kernel = Kernel(
+        "conv2dk14_i8",
+        "conv2dk14.o",
+        [
+            actIn_ty,
+            weights_ty,
+            out_ty,
+            np.int32,
+            np.int32,
+            np.int32,
+            np.int32,
+            np.int32,
+        ],
+    )
+
+    # AIE-array data movement with object fifos
+    # Input
+    of_inOF_act_L3L2 = ObjectFifo(
+        bufIn_ty,
+        name="inOF_act_L3L2",
+        dims_from_stream_per_cons=[
+            (kernel_size, kernel_size * in_channels),  # (14, 56)
+            (64, kernel_size * kernel_size * in_channels),  # (64, 784)
+            (kernel_size * in_channels, 1),  # (56, 1)
+        ],
+    )
+    of_act_L2_02 = of_inOF_act_L3L2.cons().forward(
+        obj_type=actIn_ty,
+        name="act_L2_02",
+        dims_to_stream=[
+            (2, kernel_size * kernel_size * in_channels * 8),  # (2, 6272)
+            (kernel_size * kernel_size // 2, 2 * in_channels),  # (98, 8)
+            (8, kernel_size * kernel_size * in_channels),  # (8, 784)
+            (2 * in_channels, 1),  # (8, 1)
+        ],
+    )
+
+    # wts
+    of_inOF_wts_0_L3L2 = ObjectFifo(weights_ty, depth=1, name="inOF_wts_0_L3L2")
+
+    # Output
+    of_out_02_L2 = ObjectFifo(out_ty, name="out_02_L2")
+    of_outOFL2L3 = of_out_02_L2.cons().forward(
+        obj_type=bufOut_ty,
+        name="outOFL2L3",
+        dims_to_stream=[(256, 256), (16, 8), (2, 128), (8, 1)],
+    )
+
+    # Setup a global buffer to hold runtime parameters
+    # rtp = GlobalBuffer(
+    #     np.ndarray[(16,), np.dtype[np.int32]],
+    #     name="rtp",
+    #     use_write_rtp=True,
+    # )
+
+    # rtp_barrier = WorkerRuntimeBarrier()
+
+    # Task for the core to perform
+    # def core_fn(of_wts, of_act, of_out, my_rtp, conv2dk14_i8, barrier):
+    def core_fn(of_wts, of_act, of_out, conv2dk14_i8):
+        y_dim = height // kernel_size
+        x_blocks = 4
+        x_dim = width // x_blocks  # num pixels for 1/4 of a row
+        ci = in_channels
+        co = sub_out_channels
+
+        # barrier.wait_for_value(1)
+        # scale = my_rtp[0]
+        scale = 14
+
+        elemWts = of_wts.acquire(1)
+
+        for _ in range_(y_dim):
+            for _ in range_(x_blocks):
+                elemIn = of_act.acquire(1)
+                elemOut0 = of_out.acquire(1)
+
+                conv2dk14_i8(
+                    elemIn, elemWts, elemOut0, x_dim, ci, co, kernel_size, scale
+                )
+                of_act.release(1)
+                of_out.release(1)
+        of_wts.release(1)
+
+    # Create a worker to perform the task
+    worker = Worker(
+        core_fn,
+        [
+            of_inOF_wts_0_L3L2.cons(),
+            of_act_L2_02.cons(),
+            of_out_02_L2.prod(),
+            # rtp,
+            conv2dk14_i8_kernel,
+            # rtp_barrier,
+        ],
+        stack_size=0x600,
+        trace=enable_trace,
+    )
+
+    # Runtime operations to move data to/from the AIE-array
+    rt = Runtime()
+    with rt.sequence(tensorIn_ty, tensorWeights_ty, tensorOut_ty) as (I, W, O):
+        # Initialize the runtime parameter values
+        def set_rtps(my_rtp):
+            my_rtp[0] = 14
+
+        # rt.inline_ops(set_rtps, [rtp])
+
+        # rt.set_barrier(rtp_barrier, 1)
+
+        rt.enable_trace(trace_size)
+
+        # Start worker
+        rt.start(worker)
+
+        # Fill/drain input/output ObjectFifos
+        rt.fill(of_inOF_act_L3L2.prod(), I)
+        rt.fill(of_inOF_wts_0_L3L2.prod(), W)
+        rt.drain(of_outOFL2L3.cons(), O, wait=True)
+
+    # Place components (assign them resources on the device) and generate an MLIR module
+    return Program(dev, rt).resolve_program(SequentialPlacer())
+
+
+try:
+    device_name = str(sys.argv[1])
+    if device_name == "npu":
+        dev = NPU1Col1()
+    elif device_name == "npu2":
+        dev = NPU2Col1()
+    else:
+        raise ValueError("[ERROR] Device name {} is unknown".format(sys.argv[1]))
+    width = int(sys.argv[2])
+    if width % 8 != 0 or width < 8:
+        print("Width size must be a multiple of 8 and greater than or equal to 8")
+        raise ValueError
+    height = int(sys.argv[3])
+    if height % 8 != 0 or height < 8:
+        print("Height size must be a multiple of 8 and greater than or equal to 8")
+        raise ValueError
+    in_channels = int(sys.argv[4])
+    if in_channels != 4:
+        print("Input channels size must be equal to 4")
+        raise ValueError
+    out_channels = int(sys.argv[5])
+    if out_channels != 1152:
+        print("Output channel size must be equal to 1152")
+        raise ValueError
+    kernel_size = int(sys.argv[6])
+    if kernel_size != 14:
+        print("Kernel size must be 14 right now.")
+        raise ValueError
+    trace_size = 0 if (len(sys.argv) != 8) else int(sys.argv[7])
+except ValueError:
+    print("Argument has inappropriate value")
+module = conv2dk14(
+    dev, width, height, in_channels, out_channels, kernel_size, trace_size
+)
+print(module)
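Note: the `dims_from_stream_per_cons` / `dims_to_stream` specs above describe on-the-fly data layout transforms as (size, stride) pairs. A sketch of the element ordering they produce, assuming the usual IRON convention that the first pair is the outermost loop of the copy nest:

```python
# Emulate a [(size, stride), ...] transform: nested loops, outermost first
# (assumed semantics; illustrative only).
def stream_order(dims):
    """Return the flat indices visited by the transform, in stream order."""
    idx = [0]
    for size, stride in dims:
        idx = [base + i * stride for base in idx for i in range(size)]
    return idx

# Input FIFO layout from the design above: 14 kernel rows x 64 tiles x 56
# bytes per tile row = 50176 elements, matching bufIn for a 896-wide input
# (64 tiles per row, as the kernel comments suggest).
order = stream_order([(14, 56), (64, 784), (56, 1)])
assert len(order) == 14 * 64 * 56
assert order[:3] == [0, 1, 2] and order[56] == 784  # next tile, same kernel row
```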

programming_examples/ml/conv2d_14x14/conv2dk14_32core_placed.py

Lines changed: 2 additions & 2 deletions

@@ -268,8 +268,8 @@ def sequence(I, W, O):
         tiles_to_trace=tiles_to_trace,
         shim=shim_tiles[0],
         trace_size=trace_size,
-        trace_offset=N_in_bytes,
-        ddr_id=2,
+        # trace_offset=N_in_bytes,
+        # ddr_id=2,
         coretile_events=[
             CoreEvent.INSTR_EVENT_0,
             CoreEvent.INSTR_EVENT_1,

programming_examples/ml/conv2d_14x14/conv2dk14_placed.py

Lines changed: 0 additions & 2 deletions

@@ -207,8 +207,6 @@ def sequence(I, W, O):
         tiles_to_trace=tiles_to_trace,
         shim=ShimTile,
         trace_size=trace_size,
-        trace_offset=N_in_bytes,
-        ddr_id=2,
         coretile_events=[
             CoreEvent.INSTR_EVENT_0,
             CoreEvent.INSTR_EVENT_1,

programming_examples/ml/conv2d_14x14/run_strix_makefile_placed.lit

Lines changed: 6 additions & 0 deletions

@@ -15,4 +15,10 @@
 // RUN: make -f %S/Makefile clean
 // RUN: env num_act=8 %run_on_npu2% make -f %S/Makefile trace_py devicename=npu2
 // RUN: make -f %S/Makefile clean
+// RUN: env num_act=72 %run_on_npu2% make -f %S/Makefile use_placed=0 run_py devicename=npu2
+// RUN: make -f %S/Makefile clean
+// RUN: env num_act=72 %run_on_npu2% make -f %S/Makefile use_placed=0 trace_py devicename=npu2
+// RUN: make -f %S/Makefile clean
 // RUN: env targetname=conv2dk14_32core num_act=1 %run_on_npu2% make -f %S/Makefile run_py devicename=npu2
+// RUN: make -f %S/Makefile clean
+// RUN: env targetname=conv2dk14_32core num_act=1 vectorized=false %run_on_npu2% make -f %S/Makefile run_py devicename=npu2

programming_examples/ml/conv2d_14x14/test.py

Lines changed: 10 additions & 7 deletions

@@ -121,7 +121,7 @@ def main(opts):
         dtype_out,
         enable_trace=enable_trace,
         trace_size=trace_size,
-        trace_after_output=True,
+        trace_after_output=False,
     )
 
     # ------------------------------------------------------
@@ -212,15 +212,18 @@ def forward(self, x):
     # ------------------------------------------------------
     for i in range(num_iter):
        start = time.time_ns()
-        # entire_buffer = execute(app, ifm_mem_fmt, total_wts)
-        entire_buffer = execute(app, ifm_mem_fmt_grp, total_wts)
+        if enable_trace:
+            data_buffer, trace_buffer = execute(
+                app, ifm_mem_fmt_grp, total_wts, enable_trace, False
+            )
+        else:
+            entire_buffer = execute(
+                app, ifm_mem_fmt_grp, total_wts, enable_trace, False
+            )
        stop = time.time_ns()
 
        if enable_trace:
-            # Separate data and trace
-            data_buffer, trace_buffer = extract_trace(
-                entire_buffer, shape_out, dtype_out, trace_size
-            )
+            trace_buffer = trace_buffer.view(np.uint32)
            # Scale the data
            scaled_data_buffer = data_buffer * int8_scale
            # Write out the trace
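Note: this host-side change matches the new trace plumbing. With `trace_after_output=False`, trace data is no longer appended to the output buffer and carved off with `extract_trace`; instead `execute` hands back separate data and trace buffers, and the trace bytes are simply reinterpreted as 32-bit trace words. A minimal sketch of that last step, with sizes assumed:

```python
import numpy as np

trace_size = 8192  # bytes; per the README, 8 kB / 16 kB work, 32 kB segfaults
trace_buffer = np.zeros(trace_size, dtype=np.uint8)  # stand-in for the XRT buffer
words = trace_buffer.view(np.uint32)  # reinterpret 4 bytes per trace word
assert words.size == trace_size // 4 and words.nbytes == trace_size
```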