Skip to content

Commit cf3ac50

Browse files
committed
review changes: code cleanup
1 parent 271efa5 commit cf3ac50

File tree

13 files changed

+122
-118
lines changed

13 files changed

+122
-118
lines changed

projects/composablekernel/test/ck_tile/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ add_subdirectory(core)
6666
add_subdirectory(epilogue)
6767
add_subdirectory(atomic_add_op)
6868
add_subdirectory(fmha)
69+
# TODO: The Universal GEMM tile engine test will be either removed
70+
# or moved to the appropriate location in future work.
6971
# add_subdirectory(gemm_tile_engine)
7072
add_subdirectory(pooling)
7173
add_subdirectory(grouped_conv)

projects/composablekernel/tile_engine/ops/common/utils.hpp

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
// SPDX-License-Identifier: MIT
33

44
#pragma once
5+
#include <hip/hip_version.h>
56
#include <iostream>
67
#include <functional>
78
#include <tuple>
89
#include <exception>
910
#include <sstream>
1011
#include <vector>
1112
#include <string>
13+
#include <cstdlib>
1214

1315
#include "ck_tile/core.hpp"
1416
#include "ck_tile/host.hpp"
@@ -54,17 +56,6 @@ struct PerformanceResult
5456
default: throw std::invalid_argument("Unsupported metric type");
5557
}
5658
}
57-
58-
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
59-
{
60-
os << "{\n"
61-
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
62-
<< ",\n"
63-
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
64-
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
65-
<< "}";
66-
return os;
67-
}
6859
};
6960

7061
template <typename Problem>
@@ -78,42 +69,46 @@ struct KernelInstance
7869
{
7970
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
8071
}
81-
82-
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
83-
{
84-
os << "{\n"
85-
<< " \"name\": \"" << obj.name_ << "\",\n"
86-
<< " \"problem\": " << obj.problem_ << ",\n"
87-
<< " \"perf_result\": " << obj.perf_result_ << "\n"
88-
<< "}";
89-
return os;
90-
}
9172
};
9273

93-
struct Setting
74+
template <typename Problem>
75+
std::ostream& operator<<(std::ostream& os, const KernelInstance<Problem>& obj)
9476
{
95-
int n_warmup_;
96-
int n_repeat_;
97-
bool is_gpu_timer_;
98-
int verify_;
99-
int init_method_;
100-
bool log_;
101-
std::string csv_filename_;
102-
bool flush_cache_;
103-
int rotating_count_;
104-
bool json_output_;
77+
os << "{\n"
78+
<< " \"name\": \"" << obj.name_ << "\",\n"
79+
<< " \"problem\": " << obj.problem_ << ",\n"
80+
<< " \"perf_result\": " << obj.perf_result_ << "\n"
81+
<< "}";
82+
return os;
83+
}
84+
85+
std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
86+
{
87+
os << "{\n"
88+
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ << ",\n"
89+
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
90+
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
91+
<< "}";
92+
return os;
93+
}
94+
95+
/// Benchmark run configuration handed to the GEMM profilers.
///
/// This is a plain aggregate that callers brace-initialise positionally, so
/// the member order below is part of the interface and must not be changed.
/// Each member is value-initialised so that a default-constructed Settings
/// holds zero / false / empty string instead of indeterminate values; the
/// struct remains an aggregate (C++14 and later), so positional
/// initialisation keeps working.
struct Settings
{
    int n_warmup{};             // filled from the "warmup" argument; passed on to ck_tile::stream_config
    int n_repeat{};             // filled from the "repeat" argument; passed on to ck_tile::stream_config
    bool is_gpu_timer{};        // filled from the "timer" argument; passed on to ck_tile::stream_config
    int verify{};               // non-zero enables the host-reference comparison; the value itself is
                                // forwarded to the host reference routine
    int init_method{};          // selects how input tensors are filled (0, 1 and 2 are handled by the
                                // preshuffle profiler)
    bool log{};                 // passed on to ck_tile::stream_config; also gates per-kernel printing
                                // when json_output is false
    std::string csv_filename{}; // when non-empty, results are appended to "<csv_filename>.csv"
    bool flush_cache{};         // filled from the "flush_cache" argument; passed on to ck_tile::stream_config
    int rotating_count{};       // filled from the "rotating_count" argument; passed on to ck_tile::stream_config
    bool json_output{};         // when true, the profiler prints the selected kernel as JSON only
};
106108

107109
inline std::string get_rocm_version()
108110
{
109-
std::ifstream version_file("/opt/rocm/.info/version");
110-
if(version_file.is_open())
111-
{
112-
std::string version;
113-
std::getline(version_file, version);
114-
return version;
115-
}
116-
return "Unknown";
111+
return std::to_string(HIP_VERSION_MAJOR) + "." + std::to_string(HIP_VERSION_MINOR);
117112
}
118113

119114
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>

projects/composablekernel/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "gemm/gemm_benchmark.hpp"
1515

1616
#pragma clang diagnostic push
17-
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-seggestions"
17+
#pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions"
1818

1919
// Data types and Layouts are defined by the generated kernel headers
2020
// No hardcoded type definitions here to avoid conflicts

projects/composablekernel/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python3
12
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
23
# SPDX-License-Identifier: MIT
34

@@ -9,7 +10,7 @@
910

1011

1112
def _import_gemm_benchmark():
12-
"""Import validation utilities from commons directory."""
13+
"""Import gemm benchmark from parent directory."""
1314
current_dir = os.path.dirname(os.path.abspath(__file__))
1415
parent_dir = os.path.dirname(current_dir)
1516

projects/composablekernel/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark_single.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
5858
layout_d0,
5959
layout_d1};
6060

61-
// Create Setting struct
62-
Setting setting{arg_parser.get_int("warmup"),
63-
arg_parser.get_int("repeat"),
64-
arg_parser.get_bool("timer"),
65-
arg_parser.get_int("verify"),
66-
arg_parser.get_int("init"),
67-
arg_parser.get_bool("log"),
68-
arg_parser.get_str("csv_filename"),
69-
arg_parser.get_bool("flush_cache"),
70-
arg_parser.get_int("rotating_count"),
71-
arg_parser.get_bool("json_output")};
61+
// Create Settings struct
62+
Settings setting{arg_parser.get_int("warmup"),
63+
arg_parser.get_int("repeat"),
64+
arg_parser.get_bool("timer"),
65+
arg_parser.get_int("verify"),
66+
arg_parser.get_int("init"),
67+
arg_parser.get_bool("log"),
68+
arg_parser.get_str("csv_filename"),
69+
arg_parser.get_bool("flush_cache"),
70+
arg_parser.get_int("rotating_count"),
71+
arg_parser.get_bool("json_output")};
7272

7373
// Get the profiler instance
7474
auto& profiler = GemmMultiDProfiler::instance(setting);

projects/composablekernel/tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_profiler.hpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class GemmMultiDProfiler : public GemmProfiler<GemmMultiDProfiler,
2727
ck_tile::GemmMultiDHostArgs<DsDataType::size()>>;
2828
using BaseGemm::benchmark;
2929

30-
GemmMultiDProfiler(Setting setting)
30+
GemmMultiDProfiler(Settings setting)
3131
: GemmProfiler<GemmMultiDProfiler,
3232
GemmMultiDProblem,
3333
ck_tile::GemmMultiDHostArgs<DsDataType::size()>>(setting)
@@ -141,18 +141,23 @@ class GemmMultiDProfiler : public GemmProfiler<GemmMultiDProfiler,
141141
gemm_multi_d_problem.stride_c_,
142142
is_row_major(layout_c)));
143143

144-
if(setting_.verify_)
144+
if(setting_.verify)
145145
{
146146
gemm_multi_d_host_reference(
147-
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, c_m_n_host_result);
147+
setting_.verify, a_m_k, b_k_n, d0_m_n, d1_m_n, c_m_n_host_result);
148148
}
149149

150150
for(auto& callable : callables)
151151
{
152-
auto kernel_run_result =
153-
callable(gemm_multi_d_args,
154-
ck_tile::stream_config{
155-
nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
152+
auto kernel_run_result = callable(gemm_multi_d_args,
153+
ck_tile::stream_config{nullptr,
154+
true,
155+
setting_.log,
156+
setting_.n_warmup,
157+
setting_.n_repeat,
158+
setting_.is_gpu_timer,
159+
setting_.flush_cache,
160+
setting_.rotating_count});
156161
process_result(gemm_multi_d_problem,
157162
c_m_n_dev_buf,
158163
c_m_n_host_result,

projects/composablekernel/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python3
12
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
23
# SPDX-License-Identifier: MIT
34

@@ -9,7 +10,7 @@
910

1011

1112
def _import_gemm_benchmark():
12-
"""Import validation utilities from commons directory."""
13+
"""Import gemm benchmark from parent directory."""
1314
current_dir = os.path.dirname(os.path.abspath(__file__))
1415
parent_dir = os.path.dirname(current_dir)
1516

projects/composablekernel/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
4949
layout_c,
5050
arg_parser.get_bool("structured_sparsity")};
5151

52-
// Create Setting struct
53-
Setting setting{arg_parser.get_int("warmup"),
54-
arg_parser.get_int("repeat"),
55-
arg_parser.get_bool("timer"),
56-
arg_parser.get_int("verify"),
57-
arg_parser.get_int("init"),
58-
arg_parser.get_bool("log"),
59-
arg_parser.get_str("csv_filename"),
60-
arg_parser.get_bool("flush_cache"),
61-
arg_parser.get_int("rotating_count"),
62-
arg_parser.get_bool("json_output")};
52+
// Create Settings struct
53+
Settings setting{arg_parser.get_int("warmup"),
54+
arg_parser.get_int("repeat"),
55+
arg_parser.get_bool("timer"),
56+
arg_parser.get_int("verify"),
57+
arg_parser.get_int("init"),
58+
arg_parser.get_bool("log"),
59+
arg_parser.get_str("csv_filename"),
60+
arg_parser.get_bool("flush_cache"),
61+
arg_parser.get_int("rotating_count"),
62+
arg_parser.get_bool("json_output")};
6363

6464
// Get the profiler instance
6565
auto& profiler = GemmPreshuffleProfiler::instance(setting);

projects/composablekernel/tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_profiler.hpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class GemmPreshuffleProfiler
1616
using BaseGemm = GemmProfiler<GemmPreshuffleProfiler, GemmProblem, ck_tile::GemmHostArgs>;
1717
using BaseGemm::benchmark;
1818

19-
GemmPreshuffleProfiler(Setting setting)
19+
GemmPreshuffleProfiler(Settings setting)
2020
: GemmProfiler<GemmPreshuffleProfiler, GemmProblem, ck_tile::GemmHostArgs>(setting)
2121
{
2222
}
@@ -43,17 +43,17 @@ class GemmPreshuffleProfiler
4343
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
4444
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
4545

46-
if(setting_.init_method_ == 0)
46+
if(setting_.init_method == 0)
4747
{
4848
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
4949
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
5050
}
51-
else if(setting_.init_method_ == 1)
51+
else if(setting_.init_method == 1)
5252
{
5353
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
5454
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
5555
}
56-
else if(setting_.init_method_ == 2)
56+
else if(setting_.init_method == 2)
5757
{
5858
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
5959
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
@@ -73,9 +73,9 @@ class GemmPreshuffleProfiler
7373
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
7474
c_m_n_ref.SetZero();
7575

76-
if(setting_.verify_)
76+
if(setting_.verify)
7777
{
78-
gemm_host_reference(setting_.verify_,
78+
gemm_host_reference(setting_.verify,
7979
a_m_k,
8080
b_k_n,
8181
c_m_n_ref,
@@ -89,7 +89,7 @@ class GemmPreshuffleProfiler
8989
gemm_problem.stride_c_);
9090
}
9191

92-
// Kerenl Execution
92+
// Kernel Execution
9393

9494
a_m_k_dev_buf.ToDevice(a_m_k.data());
9595
c_m_n_dev_buf.SetZero();
@@ -126,12 +126,12 @@ class GemmPreshuffleProfiler
126126
auto kernel_run_result = callable(gemm_args,
127127
ck_tile::stream_config{nullptr,
128128
true,
129-
setting_.log_,
130-
setting_.n_warmup_,
131-
setting_.n_repeat_,
132-
setting_.is_gpu_timer_,
133-
setting_.flush_cache_,
134-
setting_.rotating_count_});
129+
setting_.log,
130+
setting_.n_warmup,
131+
setting_.n_repeat,
132+
setting_.is_gpu_timer,
133+
setting_.flush_cache,
134+
setting_.rotating_count});
135135

136136
process_result(
137137
gemm_problem, c_m_n_dev_buf, c_m_n_ref, c_m_n_dev_result, kernel_run_result);

projects/composablekernel/tile_engine/ops/gemm/gemm_profiler.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ template <typename Gemm, typename Problem, typename GemmArgs>
1919
class GemmProfiler
2020
{
2121
public:
22-
static Gemm& instance(Setting setting)
22+
static Gemm& instance(Settings setting)
2323
{
2424
static Gemm instance{setting};
2525
return instance;
@@ -68,7 +68,7 @@ class GemmProfiler
6868
ck_tile::static_for<0, DDataType::size(), 1>{}([&](auto i) {
6969
using DType = ck_tile::remove_cvref_t<std::tuple_element_t<i, DDataType>>;
7070
num_byte += sizeof(DType) * gemm_problem.m_ * gemm_problem.n_;
71-
flop += sizeof(DType) * gemm_problem.m_ * gemm_problem.n_;
71+
flop += gemm_problem.m_ * gemm_problem.n_;
7272
});
7373
}
7474

@@ -77,7 +77,7 @@ class GemmProfiler
7777
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
7878
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
7979

80-
if(setting_.log_ > 0 && !setting_.json_output_)
80+
if(setting_.log > 0 && !setting_.json_output)
8181
{
8282
std::cout << kernel_instance << std::endl;
8383
}
@@ -90,7 +90,7 @@ class GemmProfiler
9090
split_k = gemm_problem.split_k_;
9191
}
9292
bool verified_correct =
93-
!setting_.verify_ ||
93+
!setting_.verify ||
9494
compare<Problem>(name, gemm_problem.k_, split_k, c_m_n_dev_result, c_m_n_host_result);
9595

9696
if(verified_correct)
@@ -119,7 +119,7 @@ class GemmProfiler
119119
b.perf_result_, a.perf_result_, metric);
120120
});
121121

122-
if(setting_.json_output_)
122+
if(setting_.json_output)
123123
{
124124
// Output clean JSON only
125125
std::cout << kernel_instance << std::endl;
@@ -132,9 +132,9 @@ class GemmProfiler
132132
std::cout << "**********************************" << std::endl;
133133
}
134134

135-
if(!setting_.csv_filename_.empty())
135+
if(!setting_.csv_filename.empty())
136136
{
137-
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
137+
std::ofstream file(setting_.csv_filename + ".csv", std::ios::app);
138138

139139
if(!file.is_open())
140140
{
@@ -182,9 +182,9 @@ class GemmProfiler
182182

183183
protected:
184184
virtual ~GemmProfiler() { kernel_instances_.clear(); }
185-
GemmProfiler(Setting setting) : setting_(setting) {}
185+
GemmProfiler(Settings setting) : setting_(setting) {}
186186

187-
Setting setting_;
187+
Settings setting_;
188188

189189
std::vector<KernelInstance<Problem>> kernel_instances_;
190190
};

0 commit comments

Comments
 (0)