Skip to content

Commit 4002ecc

Browse files
authored
trunc, fmod and Mod ONNX ops (tracel-ai#3767)
* Add trunc operation for float tensors
  Introduces a new `trunc` method for float tensors, which truncates each element toward zero. Updates documentation and test suites to include and verify the new operation.
* Add fmod operation for float tensors
  Introduces the fmod and fmod_scalar methods to compute floating-point remainders for float tensors. Updates documentation and adds comprehensive tests to verify correct behavior for various input scenarios.
* Add ONNX Mod operator support and tests
  Implements support for the ONNX Mod operator, including both fmod (C-style) and remainder (Python-style) behaviors. Adds ModNode and related codegen, updates node registration, conversion, and rank inference logic. Introduces comprehensive ONNX test models and Rust test cases for tensor, scalar, and broadcasting scenarios.
* Handle IEEE 754 edge cases in fmod tensor ops
  Improves the fmod and fmod_scalar tensor operations to correctly handle IEEE 754 special cases, including NaN, infinity, and zero divisors. Adds comprehensive tests for these edge cases and broadcasting behavior to ensure compliance and correctness.
* Fix formatting
* Mark ONNX Mod op as supported
  Updated SUPPORTED-ONNX-OPS.md to indicate that the Mod operation is now supported for both import and export.
* Refactor fmod handling and test assertions
* Improve ONNX mod tests with ReferenceEvaluator checks
* Fix zero times infinity check in fmod implementation
* Improve trunc method to handle IEEE 754 special cases
  Updated the trunc implementation to correctly handle special cases per IEEE 754, including preserving the sign of zero, infinities, and NaN. Added corresponding tests to verify correct behavior for these cases.
1 parent 0ad6699 commit 4002ecc

36 files changed

+1690
-29
lines changed

burn-book/src/building-blocks/tensor.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ Those operations are only available for `Float` tensors.
275275
| `tensor.erf()` | `tensor.erf()` |
276276
| `tensor.exp()` | `tensor.exp()` |
277277
| `tensor.floor()` | `tensor.floor()` |
278+
| `tensor.fmod(other)` | `tensor.fmod(other)` |
279+
| `tensor.fmod_scalar(scalar)` | `tensor.fmod(scalar)` |
278280
| `tensor.from_floats(floats, device)` | N/A |
279281
| `tensor.from_full_precision(tensor)` | N/A |
280282
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
@@ -295,6 +297,7 @@ Those operations are only available for `Float` tensors.
295297
| `tensor.tan()` | `tensor.tan()` |
296298
| `tensor.tanh()` | `tensor.tanh()` |
297299
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
300+
| `tensor.trunc()` | `tensor.trunc()` |
298301
| `tensor.var(dim)` | `tensor.var(dim)` |
299302
| `tensor.var_bias(dim)` | N/A |
300303
| `tensor.var_mean(dim)` | N/A |

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ functionality.
118118
| [MelWeightMatrix][103] |||
119119
| [Min][104] |||
120120
| [Mish][105] |||
121-
| [Mod][106] | ❌ | ❌ |
121+
| [Mod][106] | ✅ | ✅ |
122122
| [Mul][107] |||
123123
| [Multinomial][108] |||
124124
| [Neg][109] |||

crates/burn-import/onnx-tests/build.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ fn main() {
107107
.input("tests/div/div_shape.onnx")
108108
.input("tests/div/div_shape_tensor.onnx")
109109
.input("tests/div/div_broadcast.onnx")
110+
.input("tests/mod/modulo.onnx")
111+
.input("tests/mod/mod_scalar.onnx")
112+
.input("tests/mod/mod_remainder.onnx")
113+
.input("tests/mod/mod_fmod.onnx")
114+
.input("tests/mod/mod_broadcast_fixed.onnx")
115+
.input("tests/mod/mod_broadcast_remainder_fixed.onnx")
110116
.input("tests/dropout/dropout.onnx")
111117
.input("tests/equal/equal.onnx")
112118
.input("tests/equal/equal_shape.onnx")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: onnx-tests/tests/mod/mod.onnx
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class Model(nn.Module):
10+
def __init__(self):
11+
super(Model, self).__init__()
12+
13+
def forward(self, x, y):
14+
# Compute the modulo operation
15+
z = torch.fmod(x, y)
16+
return z
17+
18+
19+
def main():
20+
# Export to onnx
21+
model = Model()
22+
model.eval()
23+
device = torch.device("cpu")
24+
onnx_name = "mod.onnx"
25+
26+
# Create dummy inputs with proper shapes
27+
dummy_x = torch.randn(2, 3, 4, device=device)
28+
dummy_y = torch.randn(2, 3, 4, device=device)
29+
30+
torch.onnx.export(model, (dummy_x, dummy_y), onnx_name,
31+
verbose=False, opset_version=16)
32+
33+
print("Finished exporting model to {}".format(onnx_name))
34+
35+
# Output some test data for use in the test
36+
test_x = torch.tensor([[[[5.3, -5.3, 7.5, -7.5]]]])
37+
test_y = torch.tensor([[[[2.0, 2.0, 3.0, 3.0]]]])
38+
39+
print("Test input x: {}".format(test_x))
40+
print("Test input y: {}".format(test_y))
41+
output = model.forward(test_x, test_y)
42+
print("Test output: {}".format(output))
43+
44+
45+
if __name__ == '__main__':
46+
main()
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Import the shared macro
2+
use crate::include_models;
3+
include_models!(
4+
modulo,
5+
mod_scalar,
6+
mod_remainder,
7+
mod_fmod,
8+
mod_broadcast_fixed,
9+
mod_broadcast_remainder_fixed
10+
);
11+
12+
#[cfg(test)]
13+
mod tests {
14+
use super::*;
15+
use burn::tensor::{Tensor, TensorData};
16+
17+
use crate::backend::TestBackend;
18+
19+
#[test]
20+
fn mod_tensor_by_tensor() {
21+
// Initialize the model
22+
let device = Default::default();
23+
let model: modulo::Model<TestBackend> = modulo::Model::new(&device);
24+
25+
// Run the model
26+
let input_x = Tensor::<TestBackend, 3>::from_floats([[[5.3, -5.3, 7.5, -7.5]]], &device);
27+
let input_y = Tensor::<TestBackend, 3>::from_floats([[[2.0, 2.0, 3.0, 3.0]]], &device);
28+
let output = model.forward(input_x, input_y);
29+
30+
// Expected output: fmod(x, y) for each element
31+
// Using the actual computed values from Python
32+
let expected = TensorData::from([[[1.3000002f32, -1.3000002, 1.5, -1.5]]]);
33+
34+
output.to_data().assert_eq(&expected, true);
35+
}
36+
37+
#[test]
38+
fn mod_tensor_by_scalar() {
39+
// Initialize the model
40+
let device = Default::default();
41+
let model: mod_scalar::Model<TestBackend> = mod_scalar::Model::new(&device);
42+
43+
// Run the model
44+
let input_x = Tensor::<TestBackend, 4>::from_floats([[[[5.3, -5.3, 7.5, -7.5]]]], &device);
45+
let scalar = 2.0f64;
46+
let output = model.forward(input_x, scalar);
47+
48+
// Expected output: fmod(x, 2.0) for each element
49+
// Using the actual computed values from Python
50+
let expected = TensorData::from([[[[1.3000002f32, -1.3000002, 1.5, -1.5]]]]);
51+
52+
output.to_data().assert_eq(&expected, true);
53+
}
54+
55+
#[test]
56+
fn mod_remainder() {
57+
// Test fmod=0 (Python-style remainder)
58+
let device = Default::default();
59+
let model: mod_remainder::Model<TestBackend> = mod_remainder::Model::new(&device);
60+
61+
let input_x = Tensor::<TestBackend, 3>::from_floats([[[5.3, -5.3, 7.5, -7.5]]], &device);
62+
let input_y = Tensor::<TestBackend, 3>::from_floats([[[2.0, 2.0, 3.0, 3.0]]], &device);
63+
let output = model.forward(input_x, input_y);
64+
65+
// Expected: Python-style remainder where sign follows divisor
66+
// remainder(5.3, 2.0) = 1.3, remainder(-5.3, 2.0) = 0.7
67+
// remainder(7.5, 3.0) = 1.5, remainder(-7.5, 3.0) = 1.5
68+
let expected = TensorData::from([[[1.3000002f32, 0.6999998, 1.5, 1.5]]]);
69+
output.to_data().assert_eq(&expected, true);
70+
}
71+
72+
#[test]
73+
fn mod_fmod() {
74+
// Test fmod=1 (C-style fmod)
75+
let device = Default::default();
76+
let model: mod_fmod::Model<TestBackend> = mod_fmod::Model::new(&device);
77+
78+
let input_x = Tensor::<TestBackend, 3>::from_floats([[[5.3, -5.3, 7.5, -7.5]]], &device);
79+
let input_y = Tensor::<TestBackend, 3>::from_floats([[[2.0, 2.0, 3.0, 3.0]]], &device);
80+
let output = model.forward(input_x, input_y);
81+
82+
// Expected: fmod operation where sign follows dividend
83+
let expected = TensorData::from([[[1.3000002f32, -1.3000002, 1.5, -1.5]]]);
84+
output.to_data().assert_eq(&expected, true);
85+
}
86+
87+
#[test]
88+
fn mod_broadcast() {
89+
// Test broadcasting with fmod=1
90+
let device = Default::default();
91+
let model: mod_broadcast_fixed::Model<TestBackend> =
92+
mod_broadcast_fixed::Model::new(&device);
93+
94+
let input_x = Tensor::<TestBackend, 2>::from_floats(
95+
[
96+
[5.0, -7.0, 8.0, -9.0],
97+
[4.0, -6.0, 10.0, -11.0],
98+
[3.0, -5.0, 12.0, -13.0],
99+
],
100+
&device,
101+
);
102+
103+
let input_y = Tensor::<TestBackend, 4>::from_floats(
104+
[
105+
[[
106+
[3.0, 3.0, 3.0, 3.0],
107+
[3.0, 3.0, 3.0, 3.0],
108+
[3.0, 3.0, 3.0, 3.0],
109+
]],
110+
[[
111+
[4.0, 4.0, 4.0, 4.0],
112+
[4.0, 4.0, 4.0, 4.0],
113+
[4.0, 4.0, 4.0, 4.0],
114+
]],
115+
],
116+
&device,
117+
);
118+
119+
let output = model.forward(input_x, input_y);
120+
121+
// Check shape and sample values
122+
assert_eq!(output.dims(), [2, 1, 3, 4]);
123+
124+
// Check first batch, first row
125+
let data = output.to_data();
126+
let values = data.as_slice::<f32>().unwrap();
127+
// fmod(5.0, 3.0) = 2.0, fmod(-7.0, 3.0) = -1.0, etc.
128+
assert!((values[0] - 2.0).abs() < 0.001);
129+
assert!((values[1] - (-1.0)).abs() < 0.001);
130+
assert!((values[2] - 2.0).abs() < 0.001);
131+
assert!((values[3] - 0.0).abs() < 0.001);
132+
}
133+
134+
#[test]
135+
fn mod_broadcast_remainder() {
136+
// Test broadcasting with fmod=0 (remainder)
137+
let device = Default::default();
138+
let model: mod_broadcast_remainder_fixed::Model<TestBackend> =
139+
mod_broadcast_remainder_fixed::Model::new(&device);
140+
141+
let input_x =
142+
Tensor::<TestBackend, 3>::from_floats([[[7.5], [-8.5], [9.5], [-10.5]]], &device);
143+
144+
let input_y = Tensor::<TestBackend, 3>::from_floats(
145+
[
146+
[[3.0, 4.0, -3.0, -4.0, 5.0]],
147+
[[3.0, 4.0, -3.0, -4.0, 5.0]],
148+
[[3.0, 4.0, -3.0, -4.0, 5.0]],
149+
],
150+
&device,
151+
);
152+
153+
let output = model.forward(input_x, input_y);
154+
155+
// Check shape
156+
assert_eq!(output.dims(), [3, 4, 5]);
157+
158+
// Check first row values for Python-style remainder
159+
// remainder(7.5, 3.0) = 1.5, remainder(7.5, 4.0) = 3.5
160+
// remainder(7.5, -3.0) = -1.5, remainder(7.5, -4.0) = -0.5
161+
// remainder(7.5, 5.0) = 2.5
162+
let data = output.to_data();
163+
let values = data.as_slice::<f32>().unwrap();
164+
assert!((values[0] - 1.5).abs() < 0.001);
165+
assert!((values[1] - 3.5).abs() < 0.001);
166+
assert!((values[2] - (-1.5)).abs() < 0.001);
167+
assert!((values[3] - (-0.5)).abs() < 0.001);
168+
assert!((values[4] - 2.5).abs() < 0.001);
169+
}
170+
}
152 Bytes
Binary file not shown.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: onnx-tests/tests/mod/mod_broadcast.onnx
4+
# Tests broadcasting with fmod=1
5+
6+
import numpy as np
7+
import onnx
8+
from onnx import helper, TensorProto
9+
10+
def main():
11+
# Create ONNX model with Mod operator using broadcasting
12+
# Different rank tensors: 2D and 4D
13+
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4]) # 2D
14+
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 1, 3, 4]) # 4D
15+
16+
# Output tensor will have the broadcasted shape
17+
z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [2, 1, 3, 4])
18+
19+
# Create Mod node with fmod=1 (C-style fmod)
20+
mod_node = helper.make_node(
21+
'Mod',
22+
inputs=['x', 'y'],
23+
outputs=['z'],
24+
fmod=1 # C-style fmod
25+
)
26+
27+
# Create the graph
28+
graph_def = helper.make_graph(
29+
[mod_node],
30+
'mod_broadcast_model',
31+
[x, y],
32+
[z],
33+
)
34+
35+
# Create the model with opset version 16
36+
model_def = helper.make_model(
37+
graph_def,
38+
producer_name='onnx-tests',
39+
opset_imports=[helper.make_operatorsetid("", 16)]
40+
)
41+
42+
# Save the model
43+
onnx_name = "mod_broadcast.onnx"
44+
onnx.save(model_def, onnx_name)
45+
onnx.checker.check_model(onnx_name)
46+
print(f"Finished exporting model to {onnx_name}")
47+
48+
# Test with onnx.reference.ReferenceEvaluator
49+
try:
50+
from onnx.reference import ReferenceEvaluator
51+
52+
# Create test data
53+
test_x = np.array([[5.0, -7.0, 8.0, -9.0],
54+
[4.0, -6.0, 10.0, -11.0],
55+
[3.0, -5.0, 12.0, -13.0]]).astype(np.float32)
56+
57+
test_y = np.array([[[[3.0, 3.0, 3.0, 3.0],
58+
[3.0, 3.0, 3.0, 3.0],
59+
[3.0, 3.0, 3.0, 3.0]]],
60+
[[[4.0, 4.0, 4.0, 4.0],
61+
[4.0, 4.0, 4.0, 4.0],
62+
[4.0, 4.0, 4.0, 4.0]]]]).astype(np.float32)
63+
64+
# Run inference with ReferenceEvaluator
65+
sess = ReferenceEvaluator(model_def)
66+
result = sess.run(None, {"x": test_x, "y": test_y})
67+
68+
print(f"Test input x shape: {test_x.shape}")
69+
print(f"Test input y shape: {test_y.shape}")
70+
print(f"Result shape: {result[0].shape}")
71+
print(f"Sample output values (result[0,0,0,:]): {result[0][0, 0, 0, :]}")
72+
73+
# Verify expected results for fmod operation
74+
test_x_broadcast = np.broadcast_to(test_x, test_y.shape)
75+
expected_result = np.fmod(test_x_broadcast, test_y)
76+
np.testing.assert_allclose(result[0], expected_result, rtol=1e-5)
77+
print("Test passed: Results match expected fmod values")
78+
79+
except ImportError:
80+
print("onnx.reference not available, skipping inference test")
81+
82+
if __name__ == '__main__':
83+
main()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
pytorch2.8.0:�
2+
5
3+
onnx::Mod_0
4+
onnx::Mod_12/Mod"Mod*
5+
fmod�
6+
main_graphZ
7+
onnx::Mod_0
8+

9+

10+
Z%
11+
onnx::Mod_1
12+

13+

14+

15+

16+
b
17+
2
18+

19+

20+

21+

22+
B

0 commit comments

Comments
 (0)