Skip to content

Commit c632a10

Browse files
authored
Architectural changes to support multi-return feature (#285)
* Architectural changes to support multi-return feature
1 parent 323d983 commit c632a10

File tree

5 files changed

+87
-11
lines changed

5 files changed

+87
-11
lines changed

src/api/apiCore.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,27 @@ void TapeCore::tape_node_gaussian_noise_distance(const YAML::Node& yamlNode, Pla
13111311
state.nodes.insert({nodeId, node});
13121312
}
13131313

1314+
/**
 * Public API entry point that creates or updates a MultiReturnSwitchNode.
 *
 * The body runs inside rglSafeCall: it logs the call, checks the node argument,
 * and forwards both arguments to createOrUpdateNode<MultiReturnSwitchNode>.
 *
 * @param node        Pointer to the node handle; must not be nullptr (checked with CHECK_ARG).
 * @param return_type Return type forwarded to createOrUpdateNode as the node parameter.
 * @return The status value produced by rglSafeCall for the wrapped body.
 */
RGL_API rgl_status_t rgl_node_multi_return_switch(rgl_node_t* node, rgl_return_type_t return_type)
1315+
{
1316+
auto status = rglSafeCall([&]() {
1317+
RGL_API_LOG("rgl_node_multi_return_switch(node={}, return_type={})", repr(node), return_type);
1318+
CHECK_ARG(node != nullptr);
1319+
1320+
createOrUpdateNode<MultiReturnSwitchNode>(node, return_type);
1321+
});
1322+
// Invoked after rglSafeCall has returned, with the same arguments as the API call.
// NOTE(review): presumably records the call for tape playback - confirm against the TAPE_HOOK definition.
TAPE_HOOK(node, return_type);
1323+
return status;
1324+
}
1325+
1326+
/**
 * Tape handler that replays a recorded rgl_node_multi_return_switch call.
 *
 * yamlNode[0] is read as the tape ID of the node and yamlNode[1] as the integer
 * value of the return type. If state.nodes already holds a handle for that ID it is
 * passed to the API call; otherwise a null handle is passed. The resulting handle
 * is then inserted into state.nodes under the same ID.
 */
void TapeCore::tape_node_multi_return_switch(const YAML::Node& yamlNode, PlaybackState& state)
{
    const auto recordedId = yamlNode[0].as<TapeAPIObjectID>();
    const auto recordedReturnType = static_cast<rgl_return_type_t>(yamlNode[1].as<int>());

    // Start from a null handle and replace it only when this ID is already known.
    rgl_node_t handle = nullptr;
    if (state.nodes.contains(recordedId)) {
        handle = state.nodes.at(recordedId);
    }

    rgl_node_multi_return_switch(&handle, recordedReturnType);
    state.nodes.insert({recordedId, handle});
}
1334+
13141335
rgl_status_t rgl_node_is_alive(rgl_node_t node, bool* out_alive)
13151336
{
13161337
auto status = rglSafeCall([&]() {

src/gpu/nodeKernels.cu

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -229,25 +229,30 @@ __global__ void kProcessBeamSamplesFirstLast(size_t beamCount, int samplesPerBea
229229
LIMIT(beamCount);
230230

231231
const auto beamIdx = tid;
232-
int firstIdx = 0;
233-
int lastIdx = 0;
232+
int firstIdx = -1;
233+
int lastIdx = -1;
234234
for (int sampleIdx = 0; sampleIdx < samplesPerBeam; ++sampleIdx) {
235235
if (beamSamples.isHit[beamIdx * samplesPerBeam + sampleIdx] == 0) {
236236
continue;
237237
}
238-
if (beamSamples.distance[beamIdx * samplesPerBeam + sampleIdx] <
239-
beamSamples.distance[beamIdx * samplesPerBeam + firstIdx]) {
238+
auto currentFirstDistance = firstIdx >= 0 ? beamSamples.distance[beamIdx * samplesPerBeam + firstIdx] : FLT_MAX;
239+
auto currentLastDistance = lastIdx >= 0 ? beamSamples.distance[beamIdx * samplesPerBeam + lastIdx] : -FLT_MAX;
240+
if (beamSamples.distance[beamIdx * samplesPerBeam + sampleIdx] < currentFirstDistance) {
240241
firstIdx = sampleIdx;
241242
}
242-
if (beamSamples.distance[beamIdx * samplesPerBeam + sampleIdx] >
243-
beamSamples.distance[beamIdx * samplesPerBeam + lastIdx]) {
243+
if (beamSamples.distance[beamIdx * samplesPerBeam + sampleIdx] > currentLastDistance) {
244244
lastIdx = sampleIdx;
245245
}
246246
}
247-
first.xyz[beamIdx] = beamSamples.xyz[beamIdx * samplesPerBeam + firstIdx];
248-
first.distance[beamIdx] = beamSamples.distance[beamIdx * samplesPerBeam + firstIdx];
249-
last.xyz[beamIdx] = beamSamples.xyz[beamIdx * samplesPerBeam + lastIdx];
250-
last.distance[beamIdx] = beamSamples.distance[beamIdx * samplesPerBeam + lastIdx];
247+
bool isHit = firstIdx >= 0; // Note that firstIdx >= 0 implies lastIdx >= 0
248+
first.isHit[beamIdx] = isHit;
249+
last.isHit[beamIdx] = isHit;
250+
if (isHit) {
251+
first.xyz[beamIdx] = beamSamples.xyz[beamIdx * samplesPerBeam + firstIdx];
252+
first.distance[beamIdx] = beamSamples.distance[beamIdx * samplesPerBeam + firstIdx];
253+
last.xyz[beamIdx] = beamSamples.xyz[beamIdx * samplesPerBeam + lastIdx];
254+
last.distance[beamIdx] = beamSamples.distance[beamIdx * samplesPerBeam + lastIdx];
255+
}
251256
}
252257

253258

@@ -327,4 +332,4 @@ void gpuProcessBeamSamplesFirstLast(cudaStream_t stream, size_t beamCount, int s
327332
MultiReturnPointers first, MultiReturnPointers last)
328333
{
329334
run(kProcessBeamSamplesFirstLast, stream, beamCount, samplesPerBeam, beamSamples, first, last);
330-
}
335+
}

src/graph/NodesCore.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ struct RaytraceNode : IPointsNode
128128
return std::const_pointer_cast<const IAnyArray>(fieldData.at(field));
129129
}
130130

131+
IAnyArray::ConstPtr getFieldDataMultiReturn(rgl_field_t field, rgl_return_type_t);
132+
131133
// RaytraceNode specific
132134
void setVelocity(const Vec3f& linearVelocity, const Vec3f& angularVelocity);
133135
void enableRayDistortion(bool enabled) { doApplyDistortion = enabled; }
@@ -193,6 +195,31 @@ struct RaytraceNode : IPointsNode
193195
void setFields(const std::set<rgl_field_t>& fields);
194196
};
195197

198+
struct MultiReturnSwitchNode : IPointsNodeSingleInput
199+
{
200+
using Ptr = std::shared_ptr<MultiReturnSwitchNode>;
201+
void setParameters(rgl_return_type_t returnType) { this->returnType = returnType; }
202+
203+
// Node
204+
void validateImpl() override
205+
{
206+
IPointsNodeSingleInput::validateImpl();
207+
rtxInput = getExactlyOneInputOfType<RaytraceNode>();
208+
}
209+
210+
void enqueueExecImpl() override {}
211+
212+
// Data getters
213+
IAnyArray::ConstPtr getFieldData(rgl_field_t field) override
214+
{
215+
return rtxInput->getFieldDataMultiReturn(field, returnType);
216+
}
217+
218+
private:
219+
rgl_return_type_t returnType;
220+
RaytraceNode::Ptr rtxInput;
221+
};
222+
196223
struct TransformPointsNode : IPointsNodeSingleInput
197224
{
198225
using Ptr = std::shared_ptr<TransformPointsNode>;

src/graph/RaytraceNode.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,27 @@ void RaytraceNode::enqueueExecImpl()
137137
mrSamples.getPointers(), mrFirst.getPointers(), mrLast.getPointers());
138138
}
139139

140+
/**
 * Returns the array holding `field` for the requested multi-return type.
 *
 * Supported fields are XYZ_VEC3_F32, DISTANCE_F32 and IS_HIT_I32; they are read from
 * mrFirst for RGL_RETURN_TYPE_FIRST and from mrLast for RGL_RETURN_TYPE_LAST.
 *
 * @throws InvalidPipeline if the return type is neither FIRST nor LAST, or if the
 *         field is not one of the supported fields for a known return type.
 */
IAnyArray::ConstPtr RaytraceNode::getFieldDataMultiReturn(rgl_field_t field, rgl_return_type_t type)
{
    // One field lookup shared by both return types; `returns` is either mrFirst or mrLast.
    auto selectField = [field](auto& returns) -> IAnyArray::ConstPtr {
        switch (field) {
            case XYZ_VEC3_F32: return returns.xyz;
            case DISTANCE_F32: return returns.distance;
            case IS_HIT_I32: return returns.isHit;
            default:
                throw InvalidPipeline(fmt::format("Multi-Return not supported for this field ({})", toString(field)));
        }
    };

    switch (type) {
        case RGL_RETURN_TYPE_FIRST: return selectField(mrFirst);
        case RGL_RETURN_TYPE_LAST: return selectField(mrLast);
        default: throw InvalidPipeline(fmt::format("Unknown multi-return type ({})", type));
    }
}
160+
140161
void RaytraceNode::setFields(const std::set<rgl_field_t>& fields)
141162
{
142163
auto keyViewer = std::views::keys(fieldData);

src/tape/TapeCore.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class TapeCore
6868
static void tape_node_gaussian_noise_angular_ray(const YAML::Node& yamlNode, PlaybackState& state);
6969
static void tape_node_gaussian_noise_angular_hitpoint(const YAML::Node& yamlNode, PlaybackState& state);
7070
static void tape_node_gaussian_noise_distance(const YAML::Node& yamlNode, PlaybackState& state);
71+
static void tape_node_multi_return_switch(const YAML::Node& yamlNode, PlaybackState& state);
7172

7273
// Called once in the translation unit
7374
static inline bool autoExtendTapeFunctions = std::invoke([]() {
@@ -123,6 +124,7 @@ class TapeCore
123124
TAPE_CALL_MAPPING("rgl_node_gaussian_noise_angular_ray", TapeCore::tape_node_gaussian_noise_angular_ray),
124125
TAPE_CALL_MAPPING("rgl_node_gaussian_noise_angular_hitpoint", TapeCore::tape_node_gaussian_noise_angular_hitpoint),
125126
TAPE_CALL_MAPPING("rgl_node_gaussian_noise_distance", TapeCore::tape_node_gaussian_noise_distance),
127+
TAPE_CALL_MAPPING("rgl_node_multi_return_switch", TapeCore::tape_node_multi_return_switch),
126128
};
127129
TapePlayer::extendTapeFunctions(tapeFunctions);
128130
return true;

0 commit comments

Comments
 (0)