@@ -94,6 +94,21 @@ int main(int argc, char *argv[])
94
94
GetMySubarray (&my_subarray);
95
95
InitDeviceArrays (&A_device[0 ], &A_device[1 ], q, &my_subarray);
96
96
97
+ #ifdef GROUP_SIZE_DEFAULT
98
+ int work_group_size = GROUP_SIZE_DEFAULT;
99
+ #else
100
+ int work_group_size =
101
+ q.get_device ().get_info <sycl::info::device::max_work_group_size>();
102
+ #endif
103
+
104
+ if ((Nx % work_group_size) != 0 ) {
105
+ if (my_subarray.rank == 0 ) {
106
+ printf (" For simplification, sycl::info::device::max_work_group_size should be divider of X dimention of array\n " );
107
+ printf (" Please adjust matrix size, or define GROUP_SIZE_DEFAULT\n " );
108
+ printf (" sycl::info::device::max_work_group_size=%d Nx=%d (%d)\n " , work_group_size, Nx, work_group_size % Nx);
109
+ MPI_Abort (MPI_COMM_WORLD, -1 );
110
+ }
111
+ }
97
112
/* Create RMA window using device memory */
98
113
MPI_Win_create (A_device[0 ],
99
114
sizeof (double ) * (my_subarray.x_size + 2 ) * (my_subarray.y_size + 2 ),
@@ -116,18 +131,24 @@ int main(int argc, char *argv[])
116
131
{
117
132
/* Calculate values on borders to initiate communications early */
118
133
q.submit ([&](auto & h) {
119
- h.parallel_for (sycl::range (my_subarray.x_size ), [ =] (auto index) {
120
- int column = index[0 ];
121
- int idx = XY_2_IDX (column, 0 , my_subarray);
122
- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
123
- + a[idx - ROW_SIZE (my_subarray)]
124
- + a[idx + ROW_SIZE (my_subarray)]);
125
-
126
- idx = XY_2_IDX (column, my_subarray.y_size - 1 , my_subarray);
127
- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
128
- + a[idx - ROW_SIZE (my_subarray)]
129
- + a[idx + ROW_SIZE (my_subarray)]);
130
-
134
+ h.parallel_for (sycl::nd_range<1 >(work_group_size, work_group_size),
135
+ [=](sycl::nd_item<1 > item) {
136
+ int column = item.get_global_id (0 );
137
+ int col_per_wg = my_subarray.x_size / work_group_size;
138
+
139
+ int my_x_lb = col_per_wg * local_id;
140
+ int my_x_ub = my_x_lb + col_per_wg;
141
+
142
+ for (int column = my_x_lb; column < my_x_ub; column ++) {
143
+ int idx = XY_2_IDX (column, 0 , my_subarray);
144
+ a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
145
+ + a[idx - ROW_SIZE (my_subarray)]
146
+ + a[idx + ROW_SIZE (my_subarray)]);
147
+ idx = XY_2_IDX (column, my_subarray.y_size - 1 , my_subarray);
148
+ a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
149
+ + a[idx - ROW_SIZE (my_subarray)]
150
+ + a[idx + ROW_SIZE (my_subarray)]);
151
+ }
131
152
});
132
153
}).wait ();
133
154
}
@@ -149,11 +170,23 @@ int main(int argc, char *argv[])
149
170
/* Recalculate internal points in parallel with communications */
150
171
{
151
172
q.submit ([&](auto & h) {
152
- h.parallel_for (sycl::range (my_subarray.x_size , my_subarray.y_size - 2 ), [ =] (auto index) {
153
- int idx = XY_2_IDX (index[0 ], index[1 ] + 1 , my_subarray);
154
- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
155
- + a[idx - ROW_SIZE (my_subarray)]
156
- + a[idx + ROW_SIZE (my_subarray)]);
173
+ h.parallel_for (sycl::nd_range<1 >(work_group_size, work_group_size),
174
+ [=](sycl::nd_item<1 > item) {
175
+ int local_id = item.get_local_id ();
176
+ int col_per_wg = my_subarray.x_size / work_group_size;
177
+
178
+ int my_x_lb = col_per_wg * local_id;
179
+ int my_x_ub = my_x_lb + col_per_wg;
180
+
181
+ /* Recalculate internal points in parallel with communications */
182
+ for (int row = 1 ; row < my_subarray.y_size - 1 ; ++row) {
183
+ for (int column = my_x_lb; column < my_x_ub; column ++) {
184
+ int idx = XY_2_IDX (column, row, my_subarray);
185
+ a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
186
+ + a[idx - ROW_SIZE (my_subarray)]
187
+ + a[idx + ROW_SIZE (my_subarray)]);
188
+ }
189
+ }
157
190
});
158
191
}).wait ();
159
192
}
0 commit comments