Skip to content

Commit 2e4f6d2

Browse files
authored
Add NonZero ONNX operation support (tracel-ai#3745)
* Add NonZero ONNX operation support Implements NonZero operation for ONNX import with comprehensive test coverage. - Add ONNX IR layer with rank inference (nonzero.rs) - Add Burn node implementation using argwhere() and not_equal_elem() - Add conversion integration in to_burn.rs - Add 6 test models covering Float32, Int64, Bool, 1D, 3D, and empty cases - Add integration tests for all tensor types and edge cases - Update SUPPORTED-ONNX-OPS.md to mark NonZero as supported All tests pass (6 integration + 4 unit tests). * Fix formatting * Fix NonZero output format to match ONNX specification Address PR feedback by correcting output format from [num_nonzero, rank] to [rank, num_nonzero] as required by ONNX spec. Changes: - Transpose argwhere() result in Burn node implementation - Update all test expectations to match correct format - Add ReferenceEvaluator validation to Python test generation - Update ONNX model shape specifications - Update comments to reflect correct format All tests pass (13 total: 4 unit + 6 integration + 3 codegen tests). Validated against ONNX ReferenceEvaluator.
1 parent bef0beb commit 2e4f6d2

File tree

19 files changed

+752
-14
lines changed

19 files changed

+752
-14
lines changed

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ functionality.
124124
| [Neg][109] |||
125125
| [NegativeLogLikelihoodLoss][110] |||
126126
| [NonMaxSuppression][112] |||
127-
| [NonZero][113] | ||
127+
| [NonZero][113] | ||
128128
| [Not][114] |||
129129
| [OneHot][115] |||
130130
| [Optional][116] |||

crates/burn-import/onnx-tests/build.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ fn main() {
123123
.input("tests/eye_like/eye_like_1x1.onnx")
124124
.input("tests/eye_like/eye_like_wide.onnx")
125125
.input("tests/eye_like/eye_like_neg_large_k.onnx")
126+
.input("tests/nonzero/nonzero_float32.onnx")
127+
.input("tests/nonzero/nonzero_int64.onnx")
128+
.input("tests/nonzero/nonzero_bool.onnx")
129+
.input("tests/nonzero/nonzero_1d.onnx")
130+
.input("tests/nonzero/nonzero_3d.onnx")
131+
.input("tests/nonzero/nonzero_empty.onnx")
126132
.input("tests/flatten/flatten.onnx")
127133
.input("tests/flatten/flatten_2d.onnx")
128134
.input("tests/floor/floor.onnx")
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
use crate::include_models;
2+
include_models!(
3+
nonzero_float32,
4+
nonzero_int64,
5+
nonzero_bool,
6+
nonzero_1d,
7+
nonzero_3d,
8+
nonzero_empty
9+
);
10+
11+
#[cfg(test)]
12+
mod tests {
13+
use super::*;
14+
use burn::tensor::{Bool, Int, Tensor};
15+
16+
use crate::backend::TestBackend;
17+
18+
#[test]
19+
fn nonzero_float32_test() {
20+
let device = Default::default();
21+
let model = nonzero_float32::Model::<TestBackend>::new(&device);
22+
23+
// Create a 3x4 tensor with some non-zero values
24+
// Expected nonzero indices: (0,1), (1,2), (2,0), (2,3)
25+
let input = Tensor::<TestBackend, 2>::from_floats(
26+
[
27+
[0.0, 1.0, 0.0, 0.0],
28+
[0.0, 0.0, 2.5, 0.0],
29+
[3.1, 0.0, 0.0, -1.2],
30+
],
31+
&device,
32+
);
33+
34+
let output = model.forward(input);
35+
36+
// Expected indices in ONNX format [rank, num_nonzero]: [[0,1,2,2], [1,2,0,3]]
37+
let expected = Tensor::<TestBackend, 2, Int>::from_ints(
38+
[
39+
[0, 1, 2, 2], // Row indices of non-zero elements
40+
[1, 2, 0, 3], // Column indices of non-zero elements
41+
],
42+
&device,
43+
);
44+
45+
output.to_data().assert_eq(&expected.to_data(), true);
46+
}
47+
48+
#[test]
49+
fn nonzero_int64_test() {
50+
let device = Default::default();
51+
let model = nonzero_int64::Model::<TestBackend>::new(&device);
52+
53+
// Create a 2x3 int tensor with some non-zero values
54+
// Expected nonzero indices: (0,0), (1,2)
55+
let input = Tensor::<TestBackend, 2, Int>::from_ints([[5, 0, 0], [0, 0, -3]], &device);
56+
57+
let output = model.forward(input);
58+
59+
// Expected indices in ONNX format [rank, num_nonzero]: [[0,1], [0,2]]
60+
let expected = Tensor::<TestBackend, 2, Int>::from_ints(
61+
[
62+
[0, 1], // Row indices of non-zero elements
63+
[0, 2], // Column indices of non-zero elements
64+
],
65+
&device,
66+
);
67+
68+
output.to_data().assert_eq(&expected.to_data(), true);
69+
}
70+
71+
#[test]
72+
fn nonzero_bool_test() {
73+
let device = Default::default();
74+
let model = nonzero_bool::Model::<TestBackend>::new(&device);
75+
76+
// Create a 2x2 bool tensor
77+
// Expected nonzero indices: (0,1), (1,0)
78+
let input = Tensor::<TestBackend, 2, Bool>::from_bool(
79+
[[false, true], [true, false]].into(),
80+
&device,
81+
);
82+
83+
let output = model.forward(input);
84+
85+
// Expected indices in ONNX format [rank, num_nonzero]: [[0,1], [1,0]]
86+
let expected = Tensor::<TestBackend, 2, Int>::from_ints(
87+
[
88+
[0, 1], // Row indices of non-zero elements
89+
[1, 0], // Column indices of non-zero elements
90+
],
91+
&device,
92+
);
93+
94+
output.to_data().assert_eq(&expected.to_data(), true);
95+
}
96+
97+
#[test]
98+
fn nonzero_1d_test() {
99+
let device = Default::default();
100+
let model = nonzero_1d::Model::<TestBackend>::new(&device);
101+
102+
// Create a 1D tensor with some non-zero values
103+
// Expected nonzero indices: 1, 3, 5
104+
let input = Tensor::<TestBackend, 1>::from_floats([0.0, 2.0, 0.0, -1.0, 0.0, 3.5], &device);
105+
106+
let output = model.forward(input);
107+
108+
// Expected indices in ONNX format [rank, num_nonzero]: [[1, 3, 5]]
109+
let expected = Tensor::<TestBackend, 2, Int>::from_ints([[1, 3, 5]], &device);
110+
111+
output.to_data().assert_eq(&expected.to_data(), true);
112+
}
113+
114+
#[test]
115+
fn nonzero_3d_test() {
116+
let device = Default::default();
117+
let model = nonzero_3d::Model::<TestBackend>::new(&device);
118+
119+
// Create a 2x2x3 tensor with a few non-zero values
120+
// Expected nonzero indices: (0,0,1), (1,1,2)
121+
let input = Tensor::<TestBackend, 3>::from_floats(
122+
[
123+
[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
124+
[[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]],
125+
],
126+
&device,
127+
);
128+
129+
let output = model.forward(input);
130+
131+
// Expected indices in ONNX format [rank, num_nonzero]: [[0,1], [0,1], [1,2]]
132+
let expected = Tensor::<TestBackend, 2, Int>::from_ints(
133+
[
134+
[0, 1], // Dimension 0 indices of non-zero elements
135+
[0, 1], // Dimension 1 indices of non-zero elements
136+
[1, 2], // Dimension 2 indices of non-zero elements
137+
],
138+
&device,
139+
);
140+
141+
output.to_data().assert_eq(&expected.to_data(), true);
142+
}
143+
144+
#[test]
145+
fn nonzero_empty_test() {
146+
let device = Default::default();
147+
let model = nonzero_empty::Model::<TestBackend>::new(&device);
148+
149+
// Create a tensor with all zeros
150+
let input =
151+
Tensor::<TestBackend, 2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device);
152+
153+
let output = model.forward(input);
154+
155+
// Empty tensor: [2, 0] - 2 dimensions, 0 non-zero elements
156+
let expected = Tensor::<TestBackend, 2, Int>::empty([2, 0], &device);
157+
158+
output.to_data().assert_eq(&expected.to_data(), true);
159+
}
160+
}

0 commit comments

Comments
 (0)