Skip to content

Commit 9c59677

Browse files
committed
minor fixes
1 parent ccacddb commit 9c59677

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ struct reduce<Config, BinOp, 2, device_capabilities>
121121
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
122122
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
123123
lv1_val = reduction1(lv1_val);
124-
scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
124+
125+
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
126+
scratchAccessor.template set<scalar_t>(0, lv1_val[Config::ItemsPerInvocation_1-1]);
125127
}
126128
scratchAccessor.workgroupExecutionAndMemoryBarrier();
127129

@@ -130,7 +132,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
130132
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
131133
{
132134
scalar_t reduce_val;
133-
scratchAccessor.template get<scalar_t>(glsl::gl_SubgroupInvocationID(),reduce_val);
135+
scratchAccessor.template get<scalar_t>(0,reduce_val);
134136
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, hlsl::promote<vector_lv0_t>(reduce_val));
135137
}
136138
}
@@ -179,9 +181,9 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
179181
[unroll]
180182
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
181183
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
182-
vector_lv1_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv1_t>(BinOp::identity), lv1_val, bool(invocationIndex));
183-
shiftedInput = inclusiveScan1(shiftedInput);
184-
scratchAccessor.template set<scalar_t>(invocationIndex, shiftedInput[Config::ItemsPerInvocation_1-1]);
184+
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
185+
lv1_val = inclusiveScan1(lv1_val);
186+
scratchAccessor.template set<scalar_t>(invocationIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
185187
}
186188
scratchAccessor.workgroupExecutionAndMemoryBarrier();
187189

@@ -284,7 +286,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
284286
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
285287
{
286288
scalar_t reduce_val;
287-
scratchAccessor.template get<scalar_t>(glsl::gl_SubgroupInvocationID(),reduce_val);
289+
scratchAccessor.template get<scalar_t>(0,reduce_val);
288290
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, reduce_val);
289291
}
290292
}
@@ -353,8 +355,8 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
353355
[unroll]
354356
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
355357
scratchAccessor.template get<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
356-
vector_lv2_t shiftedInput = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(invocationIndex));
357-
shiftedInput = inclusiveScan2(shiftedInput);
358+
lv2_val[0] = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val[0], bool(invocationIndex));
359+
vector_lv2_t shiftedScan = inclusiveScan2(lv2_val);
358360

359361
// combine with level 1, only last element of each
360362
[unroll]
@@ -363,7 +365,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
363365
scalar_t last_val;
364366
scratchAccessor.template get<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val);
365367
scalar_t val = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(i));
366-
val = binop(last_val, shiftedInput[Config::ItemsPerInvocation_2-1]);
368+
val = binop(last_val, shiftedScan[Config::ItemsPerInvocation_2-1]);
367369
scratchAccessor.template set<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val);
368370
}
369371
}

0 commit comments

Comments
 (0)