Skip to content

Commit 7d2790d

Browse files
authored
Merge branch 'master' into master
2 parents a711de0 + 2b511e6 commit 7d2790d

File tree

1 file changed

+49
-50
lines changed

src/api_common.cpp

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -26,41 +26,57 @@ StatusCode ClearCache() {
2626
return StatusCode::kSuccess;
2727
}
2828

29+
template <typename Type>
30+
void FillCacheForPrecision(Queue& queue) {
31+
try {
32+
// Runs all the level 1 set-up functions that support all precisions
33+
Xswap<Type>(queue, nullptr);
34+
Xscal<Type>(queue, nullptr);
35+
Xcopy<Type>(queue, nullptr);
36+
Xaxpy<Type>(queue, nullptr);
37+
Xdot<Type>(queue, nullptr);
38+
Xnrm2<Type>(queue, nullptr);
39+
Xasum<Type>(queue, nullptr);
40+
Xsum<Type>(queue, nullptr);
41+
Xamax<Type>(queue, nullptr);
42+
Xmax<Type>(queue, nullptr);
43+
Xmin<Type>(queue, nullptr);
44+
45+
// Runs all the level 2 set-up functions that support all precisions
46+
Xgemv<Type>(queue, nullptr);
47+
Xgbmv<Type>(queue, nullptr);
48+
Xtrmv<Type>(queue, nullptr);
49+
Xtbmv<Type>(queue, nullptr);
50+
Xtpmv<Type>(queue, nullptr);
51+
52+
// Runs all the level 3 set-up functions that support all precisions
53+
Xgemm<Type>(queue, nullptr);
54+
Xsymm<Type>(queue, nullptr);
55+
Xsyrk<Type>(queue, nullptr);
56+
Xsyr2k<Type>(queue, nullptr);
57+
Xtrmm<Type>(queue, nullptr);
58+
59+
// Runs all the non-BLAS set-up functions
60+
Xomatcopy<Type>(queue, nullptr);
61+
62+
} catch (const RuntimeErrorCode& e) {
63+
if (e.status() != StatusCode::kNoDoublePrecision && e.status() != StatusCode::kNoHalfPrecision) {
64+
throw;
65+
}
66+
}
67+
}
68+
2969
template <typename Real, typename Complex>
3070
void FillCacheForPrecision(Queue& queue) {
3171
try {
32-
// Runs all the level 1 set-up functions
33-
Xswap<Real>(queue, nullptr);
34-
Xswap<Complex>(queue, nullptr);
35-
Xswap<Real>(queue, nullptr);
36-
Xswap<Complex>(queue, nullptr);
37-
Xscal<Real>(queue, nullptr);
38-
Xscal<Complex>(queue, nullptr);
39-
Xcopy<Real>(queue, nullptr);
40-
Xcopy<Complex>(queue, nullptr);
41-
Xaxpy<Real>(queue, nullptr);
42-
Xaxpy<Complex>(queue, nullptr);
43-
Xdot<Real>(queue, nullptr);
72+
FillCacheForPrecision<Real>(queue);
73+
FillCacheForPrecision<Complex>(queue);
74+
75+
// Runs all the level 1 set-up functions that don't support all precisions
4476
Xdotu<Complex>(queue, nullptr);
4577
Xdotc<Complex>(queue, nullptr);
46-
Xnrm2<Real>(queue, nullptr);
47-
Xnrm2<Complex>(queue, nullptr);
48-
Xasum<Real>(queue, nullptr);
49-
Xasum<Complex>(queue, nullptr);
50-
Xsum<Real>(queue, nullptr);
51-
Xsum<Complex>(queue, nullptr);
52-
Xamax<Real>(queue, nullptr);
53-
Xamax<Complex>(queue, nullptr);
54-
Xmax<Real>(queue, nullptr);
55-
Xmax<Complex>(queue, nullptr);
56-
Xmin<Real>(queue, nullptr);
57-
Xmin<Complex>(queue, nullptr);
58-
59-
// Runs all the level 2 set-up functions
60-
Xgemv<Real>(queue, nullptr);
61-
Xgemv<Complex>(queue, nullptr);
62-
Xgbmv<Real>(queue, nullptr);
63-
Xgbmv<Complex>(queue, nullptr);
78+
79+
// Runs all the level 2 set-up functions that don't support all precisions
6480
Xhemv<Complex>(queue, nullptr);
6581
Xhbmv<Complex>(queue, nullptr);
6682
Xhpmv<Complex>(queue, nullptr);
@@ -69,10 +85,6 @@ void FillCacheForPrecision(Queue& queue) {
6985
Xspmv<Real>(queue, nullptr);
7086
Xtrmv<Real>(queue, nullptr);
7187
Xtrmv<Complex>(queue, nullptr);
72-
Xtbmv<Real>(queue, nullptr);
73-
Xtbmv<Complex>(queue, nullptr);
74-
Xtpmv<Real>(queue, nullptr);
75-
Xtpmv<Complex>(queue, nullptr);
7688
Xger<Real>(queue, nullptr);
7789
Xgeru<Complex>(queue, nullptr);
7890
Xgerc<Complex>(queue, nullptr);
@@ -85,24 +97,11 @@ void FillCacheForPrecision(Queue& queue) {
8597
Xsyr2<Real>(queue, nullptr);
8698
Xspr2<Real>(queue, nullptr);
8799

88-
// Runs all the level 3 set-up functions
89-
Xgemm<Real>(queue, nullptr);
90-
Xgemm<Complex>(queue, nullptr);
91-
Xsymm<Real>(queue, nullptr);
92-
Xsymm<Complex>(queue, nullptr);
100+
// Runs all the level 3 set-up functions that don't support all precisions
93101
Xhemm<Complex>(queue, nullptr);
94-
Xsyrk<Real>(queue, nullptr);
95-
Xsyrk<Complex>(queue, nullptr);
96102
Xherk<Complex, Real>(queue, nullptr);
97-
Xsyr2k<Real>(queue, nullptr);
98-
Xsyr2k<Complex>(queue, nullptr);
99103
Xher2k<Complex, Real>(queue, nullptr);
100-
Xtrmm<Real>(queue, nullptr);
101-
Xtrmm<Complex>(queue, nullptr);
102-
103-
// Runs all the non-BLAS set-up functions
104-
Xomatcopy<Real>(queue, nullptr);
105-
Xomatcopy<Complex>(queue, nullptr);
104+
Xspr2<Real>(queue, nullptr);
106105

107106
} catch (const RuntimeErrorCode& e) {
108107
if (e.status() != StatusCode::kNoDoublePrecision && e.status() != StatusCode::kNoHalfPrecision) {
@@ -112,14 +111,14 @@ void FillCacheForPrecision(Queue& queue) {
112111
}
113112

114113
// Fills the cache with all binaries for a specific device
115-
// TODO: Add half-precision FP16 set-up calls
116114
StatusCode FillCache(const RawDeviceID device) {
117115
try {
118116
// Creates a sample context and queue to match the normal routine calling conventions
119117
auto device_cpp = Device(device);
120118
auto context = Context(device_cpp);
121119
auto queue = Queue(context, device_cpp);
122120

121+
FillCacheForPrecision<half>(queue);
123122
FillCacheForPrecision<float, float2>(queue);
124123
FillCacheForPrecision<double, double2>(queue);
125124

0 commit comments

Comments
 (0)