Skip to content

Commit b28b5db

Browse files
yuvaltassacopybara-github
authored andcommitted
Add mj_multiRayNormal for multi-ray casting with normal computation (not exposed in public header)
PiperOrigin-RevId: 847821078 Change-Id: I616441d3460406a92a10731d1a086af6163e96ad
1 parent 7fddeea commit b28b5db

File tree

3 files changed

+105
-32
lines changed

3 files changed

+105
-32
lines changed

src/engine/engine_ray.c

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -955,14 +955,6 @@ static mjtNum mj_raySdfNormal(const mjModel* m, const mjData* d, int g,
955955
return -1;
956956
}
957957

958-
959-
// intersect ray with signed distance field
960-
static mjtNum ray_sdf(const mjModel* m, const mjData* d, int g,
961-
const mjtNum pnt[3], const mjtNum vec[3]) {
962-
return mj_raySdfNormal(m, d, g, pnt, vec, NULL);
963-
}
964-
965-
966958
// intersect ray with mesh, compute normal if given
967959
static mjtNum mj_rayMeshNormal(const mjModel* m, const mjData* d, int id, const mjtNum pnt[3],
968960
const mjtNum vec[3], mjtNum normal[3]) {
@@ -1488,14 +1480,18 @@ void mju_multiRayPrepare(const mjModel* m, const mjData* d, const mjtNum pnt[3],
14881480
}
14891481

14901482

1491-
// Performs single ray intersection
1483+
// Performs single ray intersection, compute normal if given
14921484
static mjtNum mju_singleRay(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum vec[3],
1493-
int* ray_eliminate, mjtNum* geom_ba, int geomid[1]) {
1485+
int* ray_eliminate, mjtNum* geom_ba, int geomid[1],
1486+
mjtNum normal[3]) {
14941487
mjtNum dist, newdist;
1488+
mjtNum normal_local[3];
1489+
mjtNum* p_normal = normal ? normal_local : NULL;
14951490

14961491
// clear result
14971492
dist = -1;
14981493
*geomid = -1;
1494+
if (normal) mju_zero3(normal);
14991495

15001496
// get ray spherical coordinates
15011497
mjtNum azimuth = longitude(vec);
@@ -1530,25 +1526,24 @@ static mjtNum mju_singleRay(const mjModel* m, mjData* d, const mjtNum pnt[3], co
15301526
}
15311527
}
15321528

1533-
// handle mesh and hfield separately
1534-
if (m->geom_type[i] == mjGEOM_MESH) {
1535-
newdist = mj_rayMesh(m, d, i, pnt, vec);
1536-
} else if (m->geom_type[i] == mjGEOM_HFIELD) {
1537-
newdist = mj_rayHfield(m, d, i, pnt, vec);
1538-
} else if (m->geom_type[i] == mjGEOM_SDF) {
1539-
newdist = ray_sdf(m, d, i, pnt, vec);
1540-
}
1541-
1542-
// otherwise general dispatch
1543-
else {
1544-
newdist = mju_rayGeom(d->geom_xpos+3*i, d->geom_xmat+9*i,
1545-
m->geom_size+3*i, pnt, vec, m->geom_type[i]);
1529+
// dispatch to type-specific ray function
1530+
int type = m->geom_type[i];
1531+
if (type == mjGEOM_MESH) {
1532+
newdist = mj_rayMeshNormal(m, d, i, pnt, vec, p_normal);
1533+
} else if (type == mjGEOM_HFIELD) {
1534+
newdist = mj_rayHfieldNormal(m, d, i, pnt, vec, p_normal);
1535+
} else if (type == mjGEOM_SDF) {
1536+
newdist = mj_raySdfNormal(m, d, i, pnt, vec, p_normal);
1537+
} else {
1538+
newdist = mju_rayGeomNormal(d->geom_xpos+3*i, d->geom_xmat+9*i,
1539+
m->geom_size+3*i, pnt, vec, type, p_normal);
15461540
}
15471541

15481542
// update if closer intersection found
15491543
if (newdist >= 0 && (newdist < dist || dist < 0)) {
15501544
dist = newdist;
15511545
*geomid = i;
1546+
if (normal) mju_copy3(normal, normal_local);
15521547
}
15531548
}
15541549
}
@@ -1557,10 +1552,10 @@ static mjtNum mju_singleRay(const mjModel* m, mjData* d, const mjtNum pnt[3], co
15571552
}
15581553

15591554

1560-
// performs multiple ray intersections with the precomputed bv and flags
1561-
void mj_multiRay(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum* vec,
1562-
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
1563-
int* geomid, mjtNum* dist, int nray, mjtNum cutoff) {
1555+
// performs multiple ray intersections, compute normals if given
1556+
void mj_multiRayNormal(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum* vec,
1557+
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
1558+
int* geomid, mjtNum* dist, mjtNum* normal, int nray, mjtNum cutoff) {
15641559
mj_markStack(d);
15651560

15661561
// allocate source
@@ -1576,9 +1571,20 @@ void mj_multiRay(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum*
15761571
if (mju_dot3(vec+3*i, vec+3*i) < mjMINVAL) {
15771572
dist[i] = -1;
15781573
} else {
1579-
dist[i] = mju_singleRay(m, d, pnt, vec+3*i, geom_eliminate, geom_ba, geomid+i);
1574+
dist[i] = mju_singleRay(m, d, pnt, vec+3*i, geom_eliminate, geom_ba, geomid+i,
1575+
normal ? normal+3*i : NULL);
15801576
}
15811577
}
15821578

15831579
mj_freeStack(d);
15841580
}
1581+
1582+
1583+
// performs multiple ray intersections with the precomputed bv and flags
1584+
void mj_multiRay(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum vec[3],
1585+
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
1586+
int* geomid, mjtNum* dist, int nray, mjtNum cutoff) {
1587+
mj_multiRayNormal(m, d, pnt, vec, geomgroup, flg_static, bodyexclude,
1588+
geomid, dist, NULL, nray, cutoff);
1589+
}
1590+

src/engine/engine_ray.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@ MJAPI void mju_multiRayPrepare(const mjModel* m, const mjData* d,
3030
int* geom_eliminate);
3131

3232
// intersect multiple rays emanating from a single source
33-
// similar semantics to mj_ray, but vec is an array of (nray x 3) directions.
33+
// similar semantics to mj_ray, but vec is (nray x 3) and dist is (nray).
3434
MJAPI void mj_multiRay(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum* vec,
3535
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
3636
int* geomid, mjtNum* dist, int nray, mjtNum cutoff);
3737

38+
// intersect multiple rays, compute normals if given
39+
// similar semantics to mj_rayNormal, but vec, normal and dist are arrays.
40+
MJAPI void mj_multiRayNormal(const mjModel* m, mjData* d, const mjtNum pnt[3], const mjtNum* vec,
41+
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
42+
int* geomid, mjtNum* dist, mjtNum* normal, int nray, mjtNum cutoff);
43+
44+
3845
// intersect ray (pnt+x*vec, x>=0) with visible geoms, except geoms on bodyexclude
3946
// return geomid and distance (x) to nearest surface, or -1 if no intersection
4047
// geomgroup, flg_static are as in mjvOption; geomgroup==NULL skips group exclusion

test/engine/engine_ray_test.cc

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ TEST_F(RayTest, MultiRayEqualsSingleRay) {
187187
constexpr int N = 80;
188188
constexpr int M = 60;
189189
mjtNum vec[3*N*M];
190-
mjtNum pnt[3] = {1, 2, 3};
191-
mjtNum cone[4][3] = {{1, 1, -1}, {1, 1, 1}, {1, -1, -1}, {1, -1, 1}};
190+
mjtNum pnt[3] = {-1, 0, 0};
191+
mjtNum cone[4][3] = {{1, .2, -.2}, {1, .2, .2}, {1, -.2, -.2}, {1, -.2, .2}};
192192
memset(vec, 0, 3*N*M*sizeof(mjtNum));
193193

194194
for (int i = 0; i < N; ++i) {
@@ -211,15 +211,75 @@ TEST_F(RayTest, MultiRayEqualsSingleRay) {
211211
// compare results with single ray function
212212
mjtNum dist;
213213
int rgeomid;
214-
214+
int nhits = 0;
215215
for (int i = 0; i < N; ++i) {
216216
for (int j = 0; j < M; ++j) {
217217
int idx = i * M + j;
218218
dist = mj_ray(m, d, pnt, vec + 3 * idx, NULL, 1, -1, &rgeomid);
219219
EXPECT_FLOAT_EQ(dist, dist_multiray[idx]);
220220
EXPECT_EQ(rgeomid, rgeomid_multiray[idx]);
221+
nhits += dist >= 0;
222+
}
223+
}
224+
EXPECT_GT(nhits, 10);
225+
226+
mj_deleteData(d);
227+
mj_deleteModel(m);
228+
}
229+
230+
TEST_F(RayTest, MultiRayNormalEqualsSingleRayNormal) {
231+
char error[1024];
232+
mjModel* m = LoadModelFromString(kRayCastingModel, error, sizeof(error));
233+
ASSERT_THAT(m, NotNull()) << error;
234+
mjData* d = mj_makeData(m);
235+
ASSERT_THAT(d, NotNull());
236+
mj_forward(m, d);
237+
238+
// create ray array
239+
constexpr int N = 80;
240+
constexpr int M = 60;
241+
mjtNum vec[3*N*M];
242+
mjtNum pnt[3] = {-1, 0, 0};
243+
mjtNum cone[4][3] = {{1, .2, -.2}, {1, .2, .2}, {1, -.2, -.2}, {1, -.2, .2}};
244+
memset(vec, 0, 3*N*M*sizeof(mjtNum));
245+
246+
for (int i = 0; i < N; ++i) {
247+
for (int j = 0; j < M; ++j) {
248+
for (int k = 0; k < 3; ++k) {
249+
vec[3 * (i * M + j) + k] = i * cone[0][k] / (N - 1) +
250+
j * cone[1][1] / (M - 1) +
251+
(N - i - 1) * cone[2][k] / (N - 1) +
252+
(M - j - 1) * cone[3][k] / (M - 1);
253+
}
254+
}
255+
}
256+
257+
// compute intersections with multiray normal function
258+
mjtNum dist_multiray[N*M];
259+
int rgeomid_multiray[N*M];
260+
mjtNum normal_multiray[3*N*M];
261+
mj_multiRayNormal(m, d, pnt, vec, NULL, 1, -1, rgeomid_multiray,
262+
dist_multiray, normal_multiray, N * M, mjMAXVAL);
263+
264+
// compare results with single ray normal function
265+
mjtNum dist;
266+
int rgeomid;
267+
mjtNum normal[3];
268+
int nhits = 0;
269+
for (int i = 0; i < N; ++i) {
270+
for (int j = 0; j < M; ++j) {
271+
int idx = i * M + j;
272+
dist = mj_rayNormal(m, d, pnt, vec + 3 * idx, NULL, 1, -1, &rgeomid,
273+
normal);
274+
EXPECT_FLOAT_EQ(dist, dist_multiray[idx]);
275+
EXPECT_EQ(rgeomid, rgeomid_multiray[idx]);
276+
EXPECT_FLOAT_EQ(normal[0], normal_multiray[3*idx]);
277+
EXPECT_FLOAT_EQ(normal[1], normal_multiray[3*idx + 1]);
278+
EXPECT_FLOAT_EQ(normal[2], normal_multiray[3*idx + 2]);
279+
nhits += dist >= 0;
221280
}
222281
}
282+
EXPECT_GT(nhits, 10);
223283

224284
mj_deleteData(d);
225285
mj_deleteModel(m);

0 commit comments

Comments
 (0)