Skip to content

Commit 16c4c8f

Browse files
committed
add runtime params for conv1d
1 parent a9a684f commit 16c4c8f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/conv1d.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,15 @@ main(int argc, char **argv) {
140140
span1d_t(d_bias, channel_out), k, channel_in, length};
141141

142142
WorkItemDispatch wi_dispatch;
143-
wi_dispatch.set_ideal_sizes(1024, n0, n1, n2);
143+
wi_dispatch.set_ideal_sizes(params.pref_wg_size, n0, n1, n2);
144144
auto max_elem_local_mem =
145145
Q.get_device().get_info<sycl::info::device::local_mem_size>() /
146146
sizeof(real_t);
147147
wi_dispatch.adjust_sizes_mem_limit(max_elem_local_mem, n1);
148148

149149
WorkGroupDispatch wg_dispatch;
150-
wg_dispatch.set_num_work_groups(n0, n2, 1, 1, wi_dispatch.w0_,
151-
wi_dispatch.w2_);
150+
wg_dispatch.set_num_work_groups(n0, n2, params.seq_size0, params.seq_size2,
151+
wi_dispatch.w0_, wi_dispatch.w2_);
152152

153153
BkmaOptimParams optim_params{{1, n0, n0}, // BatchConfig1D dispatch_d0
154154
{1, n2, n2}, // BatchConfig1D dispatch_d2

0 commit comments

Comments
 (0)