Skip to content

Commit 1ba3459

Browse files
committed
Simplify place_partition and add const-correctness
- Simplify compute_subplaces_no_handle to use a single loop for all cases (scalar places, 1-element grids, and multi-element grids are all handled uniformly).
- Change the place parameter to const exec_place& since we only read from it.
- Make exec_place::get_place() const (it was missing the const qualifier).

Made-with: Cursor
1 parent e7a1e9e commit 1ba3459

File tree

2 files changed

+8
-22
lines changed

2 files changed

+8
-22
lines changed

cudax/include/cuda/experimental/__stf/places/place_partition.cuh

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ public:
103103
* @param place The execution place to partition
104104
* @param scope Partitioning granularity (must be cuda_device when no handle is provided)
105105
*/
106-
place_partition(exec_place place, place_partition_scope scope)
106+
place_partition(const exec_place& place, place_partition_scope scope)
107107
{
108108
#if _CCCL_CTK_BELOW(12, 4)
109109
_CCCL_ASSERT(scope != place_partition_scope::green_context, "Green contexts need an async resource handle.");
110110
#endif // _CCCL_CTK_BELOW(12, 4)
111-
compute_subplaces_no_handle(mv(place), scope);
111+
compute_subplaces_no_handle(place, scope);
112112
}
113113

114114
/** @brief Partition a vector of execution places into a single vector of subplaces (with async handle).
@@ -291,11 +291,11 @@ private:
291291
#endif
292292

293293
// If the scope requires no handle
294-
compute_subplaces_no_handle(mv(place), scope);
294+
compute_subplaces_no_handle(place, scope);
295295
}
296296

297297
/** @brief Compute the subplaces of a place at the specified granularity (scope) into the sub_places vector */
298-
void compute_subplaces_no_handle(exec_place place, place_partition_scope scope)
298+
void compute_subplaces_no_handle(const exec_place& place, place_partition_scope scope)
299299
{
300300
#if _CCCL_CTK_BELOW(12, 4)
301301
_CCCL_ASSERT(scope != place_partition_scope::green_context, "Green contexts scope need an async resource handle.");
@@ -304,23 +304,9 @@ private:
304304

305305
if (scope == place_partition_scope::cuda_device)
306306
{
307-
if (place.size() > 1)
307+
for (size_t i = 0; i < place.size(); ++i)
308308
{
309-
// Multi-element grid: extract all places
310-
for (size_t i = 0; i < place.size(); ++i)
311-
{
312-
sub_places.push_back(place.get_place(i));
313-
}
314-
}
315-
else if (place.is_device())
316-
{
317-
// Scalar device place
318-
sub_places.push_back(mv(place));
319-
}
320-
else
321-
{
322-
// 1-element grid or other scalar place: extract the underlying place
323-
sub_places.push_back(place.get_place(0));
309+
sub_places.push_back(place.get_place(i));
324310
}
325311
return;
326312
}

cudax/include/cuda/experimental/__stf/places/places.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,15 +571,15 @@ public:
571571
*
572572
* For scalar places, idx must be 0 and returns the place itself.
573573
*/
574-
exec_place get_place(size_t idx)
574+
exec_place get_place(size_t idx) const
575575
{
576576
return exec_place(pimpl->get_place(idx));
577577
}
578578

579579
/**
580580
* @brief Get the sub-place at the given multi-dimensional position
581581
*/
582-
exec_place get_place(pos4 p)
582+
exec_place get_place(pos4 p) const
583583
{
584584
return get_place(get_dims().get_index(p));
585585
}

0 commit comments

Comments
 (0)