Commit f4b1d8c

Add example and document JIT vs. non-JIT (#2628)
1 parent 5950a4d commit f4b1d8c

5 files changed, +217 −17 lines changed


programming_examples/basic/vector_reduce_min/Makefile

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ else
 endif
 
 run: ${targetname}.exe build/final.xclbin
-	${powershell} ./$< -x build/final.xclbin -i build/insts.bin -k MLIR_AIE
+	${powershell} ./$< -x build/final.xclbin -i build/insts.bin -k MLIR_AIE --warmup 10 --iters 20
 
 trace:
 	../../../python/utils/parse_trace.py --input trace.txt --mlir build/aie.mlir --output parse_eventIR_vs.json
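The new `--warmup 10 --iters 20` flags match the defaults of the measurement loop in the new `vector_reduce_min_jit.py` (10 untimed warmup iterations followed by 20 timed ones), which presumably keeps the timings reported by `make run` comparable to those printed by the JIT script.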

programming_examples/basic/vector_reduce_min/README.md

Lines changed: 59 additions & 11 deletions

@@ -4,29 +4,62 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Copyright (C) 2024, Advanced Micro Devices, Inc.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc.
 //
 //===----------------------------------------------------------------------===//-->
 
 # Vector Reduce Min:
 
-Single tile performs a very simple reduction operation where the kernel loads data from local memory, performs the `min` reduction and stores the resulting value back.
+This example showcases both **JIT** and **non-JIT** approaches for running IRON designs. A single tile performs a very simple reduction operation: the kernel loads data from local memory, performs the `min` reduction, and stores the resulting value back.
 
-Input data is brought to the local memory of the Compute tile from a Shim tile. The size of the input data `N` from the Shim tile is `1024xi32`. The data is copied to the AIE tile, where the reduction is performed. The single output data value is copied from the AIE tile to the Shim tile.
+Input data is brought to the local memory of the Compute tile from a Shim tile. The size of the input data `N` from the Shim tile is configurable (default `1024xi32` for the non-JIT version, customizable via command-line arguments for the JIT version). The data is copied to the AIE tile, where the reduction is performed. The single output data value is copied from the AIE tile to the Shim tile. The two approaches offer different compilation workflows, with the JIT version adding only microseconds of runtime overhead.
 
 ## Source Files Overview
 
-1. `vector_reduce_min.py`: A Python script that defines the AIE array structural design using MLIR-AIE operations. This generates MLIR that is then compiled using `aiecc.py` to produce design binaries (ie. XCLBIN and inst.bin for the NPU in Ryzen™ AI).
+### JIT Approach Files
 
-1. `vector_reduce_min_placed.py`: An alternative version of the design in `vector_reduce_min.py`, that is expressed in a lower-level version of IRON.
+1. **`vector_reduce_min_jit.py`**: A JIT (Just-In-Time) compiled version using IRON's `@iron.jit` decorator. This approach offers faster development iteration by compiling and executing the design at runtime, with support for command-line arguments to customize the number of elements.
 
-1. `reduce_min.cc`: A C++ implementation of a vectorized `min` reduction operation for AIE cores. The code uses the AIE API, which is a C++ header-only library providing types and operations that get translated into efficient low-level intrinsics, and whose documentation can be found [here](https://www.xilinx.com/htmldocs/xilinx2023_2/aiengine_api/aie_api/doc/index.html). The source can be found [here](../../../aie_kernels/aie2/reduce_min.cc).
+### Non-JIT Approach Files
 
-1. `test.cpp`: This C++ code is a testbench for the design example targetting Ryzen™ AI (AIE2). The code is responsible for loading the compiled XCLBIN file, configuring the AIE module, providing input data, and executing the AIE design on the NPU. After executing, the program verifies the results.
+1. **`vector_reduce_min.py`**: A Python script that defines the AIE array structural design using MLIR-AIE operations. This generates MLIR that is then compiled using `aiecc.py` to produce design binaries (i.e., XCLBIN and inst.bin for the NPU in Ryzen™ AI).
 
-## Ryzen™ AI Usage
+1. **`vector_reduce_min_placed.py`**: An alternative version of the design in `vector_reduce_min.py`, expressed in a lower-level version of IRON.
 
-### Compilation
+1. **`test.cpp`**: This C++ code is a testbench for the non-JIT design example targeting Ryzen™ AI (AIE2). The code is responsible for loading the compiled XCLBIN file, configuring the AIE module, providing input data, and executing the AIE design on the NPU. After executing, the program verifies the results.
+
+### Shared Files
+
+1. **`reduce_min.cc`**: A C++ implementation of a vectorized `min` reduction operation for AIE cores. The code uses the AIE API, which is a C++ header-only library providing types and operations that get translated into efficient low-level intrinsics, and whose documentation can be found [here](https://www.xilinx.com/htmldocs/xilinx2023_2/aiengine_api/aie_api/doc/index.html). The source can be found [here](../../../aie_kernels/aie2/reduce_min.cc).
+
+## Usage
+
+### JIT Approach (Just-In-Time Compilation)
+
+The JIT approach uses IRON's `@iron.jit` decorator for runtime compilation, offering faster development iteration and more flexible parameterization.
+
+#### Running the JIT Version
+
+To run the JIT version with default parameters (2048 elements):
+```shell
+python vector_reduce_min_jit.py
+```
+
+To run with a custom number of elements:
+```shell
+python vector_reduce_min_jit.py --num-elements 2048
+```
+
+Or using the short form:
+```shell
+python vector_reduce_min_jit.py -n 512
+```
+
+### Non-JIT Approach
+
+The non-JIT approach uses traditional MLIR-AIE compilation, where the design is compiled ahead of time to produce binaries.
+
+#### Compilation
 
 To compile the design:
 ```shell
@@ -43,11 +76,26 @@ To compile the C++ testbench:
 make vector_reduce_min.exe
 ```
 
-### C++ Testbench
+#### C++ Testbench
 
 To run the design:
-
 ```shell
 make run
 ```
 
+#### JIT vs. Non-JIT Comparison
+
+| Aspect | Non-JIT Approach | JIT Approach |
+|--------|------------------|--------------|
+| **Compilation** | Ahead-of-time via `aiecc.py` | Runtime compilation |
+| **Development Speed** | Slower (manual make/compilation) | Faster (compilation integrated) |
+| **Host Code** | C++ testbench (`test.cpp`) | Python script |
+| **Performance** | Baseline execution time | Microseconds of overhead from the JIT runtime |
+| **Flexibility** | Fixed at compile time | Runtime parameterization |
+| **Use Case** | Explicit XCLBIN management | Dynamic compilation |
+| **Binary Output** | Generates XCLBIN/inst.bin | Cached binaries in `IRON_CACHE_HOME` (defaults to `~/.iron/`) |
+
+**When to use each approach:**
+- **Use JIT** for rapid prototyping, experimentation, and runtime flexibility, and when you don't need explicit control over XCLBINs.
+- **Use non-JIT** when you need explicit XCLBIN control, are working with existing MLIR-AIE workflows, or are distributing pre-compiled binaries.
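The JIT usage described above boils down to a three-step host pattern: allocate NPU-visible tensors, call the decorated function, and read the result back. A minimal sketch, assuming `my_reduce_min` is the `@iron.jit`-decorated function from `vector_reduce_min_jit.py` (shown later in this commit), with `iron.randint`/`iron.tensor` used exactly as in that file's `main()`:

```python
import numpy as np
import aie.iron as iron
from vector_reduce_min_jit import my_reduce_min  # the @iron.jit-decorated design

# Allocate NPU-visible input and output tensors, as in the script's main().
inp = iron.randint(10, 100, (1024,), dtype=np.int32, device="npu")
out = iron.tensor((1,), dtype=np.int32, device="npu")

# The first call compiles the design (binaries cached under IRON_CACHE_HOME,
# default ~/.iron/); later calls reuse the cache with microseconds of overhead.
my_reduce_min(inp, out)
assert out.numpy()[0] == inp.numpy().min()
```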
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+// (c) Copyright 2025 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// REQUIRES: ryzen_ai, peano
+//
+// RUN: %run_on_npu1% python3 %S/vector_reduce_min_jit.py
+// RUN: %run_on_npu2% python3 %S/vector_reduce_min_jit.py
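(For context: `REQUIRES: ryzen_ai, peano` is a standard lit directive that restricts this test to configurations providing a Ryzen AI NPU and the Peano (LLVM-AIE) compiler; the two `RUN` lines execute the new JIT script through the harness's `%run_on_npu1%`/`%run_on_npu2%` substitutions, which presumably dispatch to NPU1- and NPU2-class devices respectively.)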

programming_examples/basic/vector_reduce_min/vector_reduce_min.py

Lines changed: 5 additions & 5 deletions

@@ -4,7 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 #
-# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+# (c) Copyright 2024-2025 Advanced Micro Devices, Inc. or its affiliates
 import numpy as np
 import sys
 
@@ -35,20 +35,20 @@ def my_reduce_min():
     of_out = ObjectFifo(out_ty, name="out")
 
     # AIE Core Function declarations
-    reduce_add_vector = Kernel(
+    reduce_min_vector = Kernel(
         "reduce_min_vector", "reduce_min.cc.o", [in_ty, out_ty, np.int32]
     )
 
     # Define a task
-    def core_body(of_in, of_out, reduce_add_vector):
+    def core_body(of_in, of_out, reduce_min_vector):
         elem_out = of_out.acquire(1)
         elem_in = of_in.acquire(1)
-        reduce_add_vector(elem_in, elem_out, N)
+        reduce_min_vector(elem_in, elem_out, N)
         of_in.release(1)
         of_out.release(1)
 
     # Define a worker to run the task on a core
-    worker = Worker(core_body, fn_args=[of_in.cons(), of_out.prod(), reduce_add_vector])
+    worker = Worker(core_body, fn_args=[of_in.cons(), of_out.prod(), reduce_min_vector])
 
     # Runtime operations to move data to/from the AIE-array
     rt = Runtime()
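Beyond the rename, this file marks the main API difference from the JIT flow: here `Kernel` binds the core function to an object file pre-built by the Makefile, whereas the new JIT script compiles the same source at runtime through `ExternalFunction`. Both declarations side by side, taken from the two files in this commit (import paths and the `N`/`kernel_dir` bindings are assumptions added here for a self-contained sketch):

```python
import os
import numpy as np
from aie.iron import Kernel, ExternalFunction

# Types/paths as defined in the two design files (repeated here for completeness).
N = 1024
in_ty = np.ndarray[(N,), np.dtype[np.int32]]
out_ty = np.ndarray[(1,), np.dtype[np.int32]]
kernel_dir = "aie_kernels/aie2"  # relative to the repository root

# Non-JIT (vector_reduce_min.py): bind to the pre-compiled kernel object.
reduce_min_vector = Kernel(
    "reduce_min_vector", "reduce_min.cc.o", [in_ty, out_ty, np.int32]
)

# JIT (vector_reduce_min_jit.py): compile the same kernel from source at runtime.
reduce_min_vector = ExternalFunction(
    "reduce_min_vector",
    source_file=os.path.join(kernel_dir, "reduce_min.cc"),
    arg_types=[in_ty, out_ty, np.int32],
    include_dirs=[kernel_dir],
)
```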
programming_examples/basic/vector_reduce_min/vector_reduce_min_jit.py

Lines changed: 145 additions & 0 deletions

@@ -0,0 +1,145 @@
+# vector_reduce_min/vector_reduce_min_jit.py -*- Python -*-
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2025 Advanced Micro Devices, Inc. or its affiliates
+import numpy as np
+import sys
+import os
+import argparse
+import time
+
+import aie.iron as iron
+from aie.iron import ObjectFifo, Program, Runtime, Worker
+from aie.iron.placers import SequentialPlacer
+from aie.iron import ExternalFunction
+
+
+@iron.jit(is_placed=False)
+def my_reduce_min(input_tensor, output_tensor):
+
+    num_elements = input_tensor.numel()
+    assert output_tensor.numel() == 1, "Output tensor must be a scalar"
+
+    # Define tensor types
+    in_ty = np.ndarray[(num_elements,), np.dtype[input_tensor.dtype]]
+    out_ty = np.ndarray[(1,), np.dtype[output_tensor.dtype]]
+
+    # AIE-array data movement with object fifos
+    of_in = ObjectFifo(in_ty, name="in")
+    of_out = ObjectFifo(out_ty, name="out")
+
+    # AIE Core Function declarations
+    root_dir = os.path.abspath(os.path.join(__file__, "../../../.."))
+    kernel_dir = os.path.join(root_dir, "aie_kernels/aie2")
+    source_file = os.path.join(kernel_dir, "reduce_min.cc")
+    reduce_min_vector = ExternalFunction(
+        "reduce_min_vector",
+        source_file=source_file,
+        arg_types=[in_ty, out_ty, np.int32],
+        include_dirs=[kernel_dir],
+    )
+
+    # Define a task
+    def core_body(of_in, of_out, reduce_min_vector):
+        elem_out = of_out.acquire(1)
+        elem_in = of_in.acquire(1)
+        reduce_min_vector(elem_in, elem_out, num_elements)
+        of_in.release(1)
+        of_out.release(1)
+
+    # Define a worker to run the task on a core
+    worker = Worker(core_body, fn_args=[of_in.cons(), of_out.prod(), reduce_min_vector])
+
+    # Runtime operations to move data to/from the AIE-array
+    rt = Runtime()
+    with rt.sequence(in_ty, out_ty) as (a_in, c_out):
+        rt.start(worker)
+        rt.fill(of_in.prod(), a_in)
+        rt.drain(of_out.cons(), c_out, wait=True)
+
+    # Place program components (assign them resources on the device) and generate an MLIR module
+    return Program(iron.get_current_device(), rt).resolve_program(SequentialPlacer())
+
+
+def main():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-n",
+        "--num-elements",
+        type=int,
+        default=2048,
+        help="Number of elements (default: 2048)",
+    )
+    parser.add_argument(
+        "-w",
+        "--warmup",
+        type=int,
+        default=10,
+        help="Number of warmup iterations (default: 10)",
+    )
+    parser.add_argument(
+        "-i",
+        "--iters",
+        type=int,
+        default=20,
+        help="Number of measurement iterations (default: 20)",
+    )
+
+    args = parser.parse_args()
+    num_elements = args.num_elements
+    n_warmup_iterations = args.warmup
+    n_iterations = args.iters
+    data_type = np.int32
+
+    # Construct input and output tensors that are accessible to the NPU
+    input_tensor = iron.randint(10, 100, (num_elements,), dtype=data_type, device="npu")
+    output_tensor = iron.tensor((1,), dtype=data_type, device="npu")
+
+    # Initialize timing variables
+    npu_time_total = 0.0
+    npu_time_min = float("inf")
+    npu_time_max = 0.0
+
+    # Main run loop with warmup and measurement iterations
+    total_iterations = n_warmup_iterations + n_iterations
+    for iter_num in range(total_iterations):
+        # Launch the kernel and measure execution time
+        start_time = time.perf_counter()
+        my_reduce_min(input_tensor, output_tensor)
+        end_time = time.perf_counter()
+
+        # Calculate execution time in microseconds
+        execution_time_us = (end_time - start_time) * 1_000_000
+
+        # Skip warmup iterations for timing statistics
+        if iter_num >= n_warmup_iterations:
+            npu_time_total += execution_time_us
+            npu_time_min = min(npu_time_min, execution_time_us)
+            npu_time_max = max(npu_time_max, execution_time_us)
+
+    # Check the correctness of the result
+    computed = output_tensor.numpy()[0]
+    expected = input_tensor.numpy().min()
+
+    if expected == computed:
+        # Print timing results
+        if n_iterations > 1:
+            avg_time = npu_time_total / n_iterations
+            print(f"\nAvg NPU time: {avg_time:.1f}us.")
+            print(f"Min NPU time: {npu_time_min:.1f}us.")
+            print(f"Max NPU time: {npu_time_max:.1f}us.")
+        else:
+            print(f"\nNPU time: {npu_time_total:.1f}us.")
+        print("PASS!")
+        sys.exit(0)
+    else:
+        print(f"FAIL!: Expected {expected} but got {computed}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
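One subtlety in the file above is the kernel-path computation: `os.path.join(__file__, "../../../..")` resolves to the repository root because the first `..` cancels the filename and the remaining three climb out of `vector_reduce_min/`, `basic/`, and `programming_examples/`. A standalone sketch with an illustrative path:

```python
import os

# Illustrative path only; mirrors the resolution done in vector_reduce_min_jit.py.
f = "/repo/programming_examples/basic/vector_reduce_min/vector_reduce_min_jit.py"
root = os.path.abspath(os.path.join(f, "../../../.."))
print(root)  # -> /repo, i.e. the checkout containing aie_kernels/aie2/reduce_min.cc
```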
