Skip to content

Commit 55385a4

Browse files
committed
Add validation and unit tests
1 parent 65fb61b commit 55385a4

File tree

2 files changed

+124
-19
lines changed

2 files changed

+124
-19
lines changed

onnxruntime/core/providers/cpu/tensor/affine_grid.cc

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include "core/common/common.h"
77
#include "core/providers/op_kernel_type_control.h"
88
#include "core/util/math_cpuonly.h"
9-
#include <iostream>
10-
#include "Eigen/src/Core/Map.h"
119
#include <Eigen/Dense>
1210
#include "core/common/eigen_common_wrapper.h"
1311

@@ -28,13 +26,14 @@ REGISTER_KERNEL_TYPED(double)
2826

2927
template <typename T>
3028
void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix<T, Eigen::Dynamic, 2>& base_grid) {
31-
Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
29+
using VectorT = Eigen::Matrix<T, Eigen::Dynamic, 1>;
30+
VectorT row_vec = VectorT::LinSpaced(static_cast<Eigen::Index>(W), static_cast<T>(-1), static_cast<T>(1));
3231
if (!align_corners) {
33-
row_vec = row_vec * (W - 1) / W;
32+
row_vec = row_vec * static_cast<T>(W - 1) / static_cast<T>(W);
3433
}
35-
Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
34+
VectorT col_vec = VectorT::LinSpaced(static_cast<Eigen::Index>(H), static_cast<T>(-1), static_cast<T>(1));
3635
if (!align_corners) {
37-
col_vec = col_vec * (H - 1) / H;
36+
col_vec = col_vec * static_cast<T>(H - 1) / static_cast<T>(H);
3837
}
3938

4039
base_grid.resize(static_cast<Eigen::Index>(H * W), 2);
@@ -47,17 +46,18 @@ void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matr
4746

4847
template <typename T>
4948
void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix<T, Eigen::Dynamic, 3>& base_grid) {
50-
Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
49+
using VectorT = Eigen::Matrix<T, Eigen::Dynamic, 1>;
50+
VectorT row_vec = VectorT::LinSpaced(static_cast<Eigen::Index>(W), static_cast<T>(-1), static_cast<T>(1));
5151
if (!align_corners) {
52-
row_vec = row_vec * (W - 1) / W;
52+
row_vec = row_vec * static_cast<T>(W - 1) / static_cast<T>(W);
5353
}
54-
Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
54+
VectorT col_vec = VectorT::LinSpaced(static_cast<Eigen::Index>(H), static_cast<T>(-1), static_cast<T>(1));
5555
if (!align_corners) {
56-
col_vec = col_vec * (H - 1) / H;
56+
col_vec = col_vec * static_cast<T>(H - 1) / static_cast<T>(H);
5757
}
58-
Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(D), -1, 1);
58+
VectorT slice_vec = VectorT::LinSpaced(static_cast<Eigen::Index>(D), static_cast<T>(-1), static_cast<T>(1));
5959
if (!align_corners) {
60-
slice_vec = slice_vec * (D - 1) / D;
60+
slice_vec = slice_vec * static_cast<T>(D - 1) / static_cast<T>(D);
6161
}
6262

6363
base_grid.resize(static_cast<Eigen::Index>(D * H * W), 3);
@@ -75,25 +75,27 @@ void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix<T, 2, Eig
7575
const Eigen::StorageOptions option = Eigen::RowMajor;
7676
auto theta_batch_offset = batch_num * 2 * 3;
7777
const T* theta_data = theta->Data<T>() + theta_batch_offset;
78-
const Eigen::Matrix<T, 2, 2, option> theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}};
79-
const Eigen::Array<T, 2, 1> theta_T(theta_data[2], theta_data[5]);
78+
const Eigen::Matrix<T, 2, 2, option> theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; // 2x2
79+
const Eigen::Array<T, 2, 1> theta_T(theta_data[2], theta_data[5]); // 2x1
8080

8181
auto grid_batch_offset = batch_num * H * W * 2;
8282
T* grid_data = grid->MutableData<T>() + grid_batch_offset;
8383
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 2, option>> grid_matrix(grid_data, narrow<size_t>(H * W), 2);
84-
grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
84+
grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); //(2x2 * 2xN).colwise() + 2x1).transpose()
8585
}
8686

8787
template <typename T>
8888
void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix<T, 3, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) {
8989
const Eigen::StorageOptions option = Eigen::RowMajor;
9090
auto theta_batch_offset = batch_num * 3 * 4;
9191
const T* theta_data = theta->Data<T>() + theta_batch_offset;
92+
9293
const Eigen::Matrix<T, 3, 3, option> theta_R{
9394
{theta_data[0], theta_data[1], theta_data[2]},
9495
{theta_data[4], theta_data[5], theta_data[6]},
95-
{theta_data[8], theta_data[9], theta_data[10]}};
96-
const Eigen::Array<T, 3, 1> theta_T(theta_data[3], theta_data[7], theta_data[11]);
96+
{theta_data[8], theta_data[9], theta_data[10]}}; // 3x3
97+
98+
const Eigen::Array<T, 3, 1> theta_T(theta_data[3], theta_data[7], theta_data[11]); // 3x1
9799

98100
auto grid_batch_offset = batch_num * D * H * W * 3;
99101
T* grid_data = grid->MutableData<T>() + grid_batch_offset;
@@ -113,9 +115,15 @@ Status AffineGrid<T>::Compute(OpKernelContext* context) const {
113115
const TensorShape& size_shape = size->Shape();
114116
const int64_t* size_data = size->Data<int64_t>();
115117

116-
if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {
118+
if (size_shape.GetDims()[0] == 4) {
117119
int64_t N = size_data[0], H = size_data[2], W = size_data[3];
118120

121+
ORT_RETURN_IF(N != theta_shape[0],
122+
"AffineGrid: size[0] (", N, ") must equal theta batch dimension (", theta_shape[0], ")");
123+
ORT_RETURN_IF(theta_shape[1] != 2 || theta_shape[2] != 3,
124+
"AffineGrid: theta shape must be [N, 2, 3] for 2D, got [",
125+
theta_shape[0], ", ", theta_shape[1], ", ", theta_shape[2], "]");
126+
119127
TensorShape grid_shape{N, H, W, 2};
120128
auto grid = context->Output(0, grid_shape);
121129

@@ -128,9 +136,15 @@ Status AffineGrid<T>::Compute(OpKernelContext* context) const {
128136
};
129137

130138
concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(N), std::move(fn), 0);
131-
} else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {
139+
} else if (size_shape.GetDims()[0] == 5) {
132140
int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4];
133141

142+
ORT_RETURN_IF(N != theta_shape[0],
143+
"AffineGrid: size[0] (", N, ") must equal theta batch dimension (", theta_shape[0], ")");
144+
ORT_RETURN_IF(theta_shape[1] != 3 || theta_shape[2] != 4,
145+
"AffineGrid: theta shape must be [N, 3, 4] for 3D, got [",
146+
theta_shape[0], ", ", theta_shape[1], ", ", theta_shape[2], "]");
147+
134148
TensorShape grid_shape{N, D, H, W, 3};
135149
auto grid = context->Output(0, grid_shape);
136150

onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "core/util/math.h"
55
#include "gtest/gtest.h"
66
#include "test/providers/provider_test_utils.h"
7+
#include "test/unittest_util/op_tester.h" // Ensure this include is present
78

89
namespace onnxruntime {
910
namespace test {
@@ -178,5 +179,95 @@ TEST(AffineGridTest, test_3d_7) {
178179
test.SetOutputTolerance(0.0001f);
179180
test.Run();
180181
}
182+
183+
// Validation tests for input shape checks
184+
185+
// Test: theta must be a 3D tensor
186+
TEST(AffineGridTest, invalid_theta_not_3d) {
187+
OpTester test("AffineGrid", 20);
188+
// theta is 2D instead of 3D
189+
test.AddInput<float>("theta", {2, 6}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});
190+
test.AddInput<int64_t>("size", {4}, {2, 1, 2, 3});
191+
test.AddOutput<float>("grid", {2, 2, 3, 2}, std::vector<float>(24, 0.0f));
192+
test.Run(OpTester::ExpectResult::kExpectFailure, "AffineGrid : Input theta tensor dimension is not 3");
193+
}
194+
195+
// Test: size length must be 4 or 5
196+
TEST(AffineGridTest, invalid_size_length_3) {
197+
OpTester test("AffineGrid", 20);
198+
test.AddInput<float>("theta", {1, 2, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});
199+
test.AddInput<int64_t>("size", {3}, {1, 1, 2});
200+
test.AddOutput<float>("grid", {1, 2, 2}, std::vector<float>(4, 0.0f));
201+
test.Run(OpTester::ExpectResult::kExpectFailure, "Length of input 'size' is 3. It must be 4 for 2D or 5 for 5D");
202+
}
203+
204+
// Test: size length must be 4 or 5 (too long)
205+
TEST(AffineGridTest, invalid_size_length_6) {
206+
OpTester test("AffineGrid", 20);
207+
test.AddInput<float>("theta", {1, 2, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});
208+
test.AddInput<int64_t>("size", {6}, {1, 1, 2, 3, 4, 5});
209+
test.AddOutput<float>("grid", {1, 2, 3, 2}, std::vector<float>(12, 0.0f));
210+
test.Run(OpTester::ExpectResult::kExpectFailure, "Length of input 'size' is 6. It must be 4 for 2D or 5 for 5D");
211+
}
212+
213+
// Test: 2D - batch dimension mismatch between theta and size
214+
TEST(AffineGridTest, invalid_2d_batch_mismatch) {
215+
OpTester test("AffineGrid", 20);
216+
// theta has N=1, but size has N=2
217+
test.AddInput<float>("theta", {1, 2, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});
218+
test.AddInput<int64_t>("size", {4}, {2, 1, 2, 3});
219+
test.AddOutput<float>("grid", {2, 2, 3, 2}, std::vector<float>(24, 0.0f));
220+
test.Run(OpTester::ExpectResult::kExpectFailure, "must equal theta batch dimension");
221+
}
222+
223+
// Test: 2D - theta shape must be [N, 2, 3], wrong second dimension
224+
TEST(AffineGridTest, invalid_2d_theta_wrong_dim1) {
225+
OpTester test("AffineGrid", 20);
226+
// theta is [1, 3, 3] but for 2D it must be [N, 2, 3]
227+
test.AddInput<float>("theta", {1, 3, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f});
228+
test.AddInput<int64_t>("size", {4}, {1, 1, 2, 3});
229+
test.AddOutput<float>("grid", {1, 2, 3, 2}, std::vector<float>(12, 0.0f));
230+
test.Run(OpTester::ExpectResult::kExpectFailure, "theta shape must be [N, 2, 3] for 2D");
231+
}
232+
233+
// Test: 2D - theta shape must be [N, 2, 3], wrong third dimension
234+
TEST(AffineGridTest, invalid_2d_theta_wrong_dim2) {
235+
OpTester test("AffineGrid", 20);
236+
// theta is [1, 2, 4] but for 2D it must be [N, 2, 3]
237+
test.AddInput<float>("theta", {1, 2, 4}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f});
238+
test.AddInput<int64_t>("size", {4}, {1, 1, 2, 3});
239+
test.AddOutput<float>("grid", {1, 2, 3, 2}, std::vector<float>(12, 0.0f));
240+
test.Run(OpTester::ExpectResult::kExpectFailure, "theta shape must be [N, 2, 3] for 2D");
241+
}
242+
243+
// Test: 3D - batch dimension mismatch between theta and size
244+
TEST(AffineGridTest, invalid_3d_batch_mismatch) {
245+
OpTester test("AffineGrid", 20);
246+
// theta has N=1, but size has N=2
247+
test.AddInput<float>("theta", {1, 3, 4}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f});
248+
test.AddInput<int64_t>("size", {5}, {2, 1, 2, 2, 3});
249+
test.AddOutput<float>("grid", {2, 2, 2, 3, 3}, std::vector<float>(72, 0.0f));
250+
test.Run(OpTester::ExpectResult::kExpectFailure, "must equal theta batch dimension");
251+
}
252+
253+
// Test: 3D - theta shape must be [N, 3, 4], wrong second dimension
254+
TEST(AffineGridTest, invalid_3d_theta_wrong_dim1) {
255+
OpTester test("AffineGrid", 20);
256+
// theta is [1, 2, 4] but for 3D it must be [N, 3, 4]
257+
test.AddInput<float>("theta", {1, 2, 4}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f});
258+
test.AddInput<int64_t>("size", {5}, {1, 1, 2, 2, 3});
259+
test.AddOutput<float>("grid", {1, 2, 2, 3, 3}, std::vector<float>(36, 0.0f));
260+
test.Run(OpTester::ExpectResult::kExpectFailure, "theta shape must be [N, 3, 4] for 3D");
261+
}
262+
263+
// Test: 3D - theta shape must be [N, 3, 4], wrong third dimension
264+
TEST(AffineGridTest, invalid_3d_theta_wrong_dim2) {
265+
OpTester test("AffineGrid", 20);
266+
// theta is [1, 3, 3] but for 3D it must be [N, 3, 4]
267+
test.AddInput<float>("theta", {1, 3, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f});
268+
test.AddInput<int64_t>("size", {5}, {1, 1, 2, 2, 3});
269+
test.AddOutput<float>("grid", {1, 2, 2, 3, 3}, std::vector<float>(36, 0.0f));
270+
test.Run(OpTester::ExpectResult::kExpectFailure, "theta shape must be [N, 3, 4] for 3D");
271+
}
181272
} // namespace test
182273
} // namespace onnxruntime

0 commit comments

Comments
 (0)