Skip to content

Commit 0c259bd

Browse files
authored
[CK][CK Tile] Grouped Convolution Backward Weight set of fixes (#5387)
## Motivation Grouped Convolution Backward Weight split-k fixes for CK tile kernels ## Technical Details - get k batch from kargs to use the deduced k batch - multiply zeroing size by data type size - disable v6 (producing incorrect results) ## Test Plan test_grouped_convnd_bwd_weight_tile ## Test Result Pass ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Ville Pietilä <>
1 parent be9be5f commit 0c259bd

File tree

6 files changed

+28
-11
lines changed

6 files changed

+28
-11
lines changed

projects/composablekernel/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ struct GroupedConvolutionBackwardWeightInvoker
126126
}
127127

128128
auto preprocess = [&]() {
129-
if(args.k_batch > 1)
129+
if(kargs.k_batch > 1)
130130
{
131131
ck_tile::hip_check_error(hipMemsetAsync(
132132
kargs.wei_ptr, 0, args.template GetWeightByte<WeiDataType>(), s.stream_id_));

projects/composablekernel/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
180180
}
181181

182182
auto preprocess = [&]() {
183-
if(args.k_batch > 1)
183+
if(kargs.k_batch > 1)
184184
ck_tile::hip_check_error(
185185
hipMemsetAsync(ws_args.wei_ptr,
186186
0,

projects/composablekernel/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ck_tile/builder/testing/testing.hpp"
77
#include "ck_tile/builder/testing/conv/fwd.hpp"
88
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
9+
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
910
#include "ck_tile/host/kernel_launch.hpp"
1011
#include "ck_tile/ops/gemm.hpp"
1112
#include "ck_tile/ops/grouped_convolution.hpp"
@@ -56,6 +57,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
5657
if(!Conv::IsSupportedArgument(kargs))
5758
return RunResult::not_supported("unsupported ck_tile arguments");
5859

60+
using Types = ck_tile::builder::factory::internal::TileConvTensorTypes<SIGNATURE.data_type>;
5961
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
6062
std::end(kargs.wei_g_k_c_xs_lengths.data),
6163
1,
@@ -64,10 +66,13 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
6466
auto preprocess = [&]() {
6567
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
6668
{
67-
if(args.k_batch > 1)
69+
if(kargs.k_batch > 1)
6870
{
6971
ck_tile::hip_check_error(
70-
hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_));
72+
hipMemsetAsync(kargs.wei_ptr,
73+
0,
74+
zeroing_size * sizeof(typename Types::EDataType),
75+
s_conf.stream_id_));
7176
}
7277
}
7378
};
@@ -156,7 +161,7 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
156161
auto preprocess = [&]() {
157162
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
158163
{
159-
if(args.k_batch > 1)
164+
if(kargs.k_batch > 1)
160165
{
161166
ck_tile::hip_check_error(
162167
hipMemsetAsync(ws_args.wei_ptr,

projects/composablekernel/experimental/grouped_convolution_tile_instances/generate_instances.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ def parse_bwd_weight_instances(instances, problem_name):
447447
if check_vectors(a_scalar_per_vector, b_scalar_per_vector, c_scalar_per_vector) == False:
448448
print(f"Skipping instance {instance_id} with irregular load since it's not supported yet.")
449449
continue
450-
450+
if pipeline_version == "V6":
451+
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
452+
continue
451453

452454
conv = ConvInstanceTemplateParams(
453455
spec,

projects/composablekernel/profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void run_cpu_validation(const ckt::Args<SIGNATURE>& args,
7373

7474
template <auto SIGNATURE>
7575
std::tuple<double, double>
76-
get_rtol_atol(const int num_accums, const int num_accums_split_k, const float max_accumulated_value)
76+
get_rtol_atol(const int num_accums, const int k_batch, const float max_accumulated_value)
7777
{
7878
using WeiDataType =
7979
std::conditional_t<SIGNATURE.data_type == ckb::DataType::FP32,
@@ -84,6 +84,8 @@ get_rtol_atol(const int num_accums, const int num_accums_split_k, const float ma
8484
using ComputeType = WeiDataType;
8585
using AccDataType = float;
8686

87+
// Assign middle value of the range for auto deduce
88+
const int num_accums_split_k = k_batch > 0 ? k_batch : 64;
8789
auto rtol = ck_tile::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
8890
num_accums / num_accums_split_k);
8991
auto atol = ck_tile::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
@@ -150,14 +152,17 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
150152
auto run_alg = [&](auto&& run_alg_func) {
151153
for(auto& k_batch : split_k_values)
152154
{
153-
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
155+
ckt::Args<SIGNATURE> args_k_batch = args;
156+
args_k_batch.k_batch = k_batch;
157+
std::tie(is_supported, avg_time, op_name) =
158+
run_alg_func(args_k_batch, inputs, outputs, s_conf);
154159
if(is_supported)
155160
{
156161
ckt::ValidationReport report;
157162
auto&& [rtol, atol] =
158163
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
159164
ckt::Outputs<SIGNATURE>::reflect(
160-
args,
165+
args_k_batch,
161166
[&](std::string_view name,
162167
const auto& desc,
163168
void* ckt::Outputs<SIGNATURE>::*ptr) {
@@ -182,7 +187,7 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
182187
<< " Is all zero:" << error.is_all_zero()
183188
<< " max err: " << error.max_error << std::endl;
184189
// Check with cpu verification to get a values
185-
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
190+
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
186191
}
187192
all_instances_valid = false;
188193
}

projects/composablekernel/profiler/src/profile_grouped_conv_bwd_weight_tile.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,12 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, const std::string& split_k,
136136
split_k,
137137
inputs.get(),
138138
outputs.get(),
139-
ck_tile::stream_config{nullptr, time_kernel});
139+
ck_tile::stream_config{nullptr,
140+
time_kernel,
141+
0 /*log_level*/,
142+
5 /*cold_iters*/,
143+
50 /*nrepeat_*/,
144+
true /*is_gpu_timer_*/});
140145
if(time_kernel)
141146
{
142147
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name

0 commit comments

Comments
 (0)