@@ -26,41 +26,57 @@ StatusCode ClearCache() {
2626 return StatusCode::kSuccess ;
2727}
2828
// Runs the set-up call of every routine listed below, instantiated for the single
// precision 'Type'. Each call is given 'queue' and a nullptr second argument.
//
// All calls share one try block, so the first call that throws ends the sequence
// and the routines after it are not attempted. A RuntimeErrorCode whose status is
// kNoDoublePrecision or kNoHalfPrecision is swallowed; a RuntimeErrorCode with any
// other status is re-thrown to the caller. Other exception types are not caught here.
29+ template <typename Type>
30+ void FillCacheForPrecision (Queue& queue) {
31+ try {
32+ // Runs all the level 1 set-up functions that support all precisions
33+ Xswap<Type>(queue, nullptr );
34+ Xscal<Type>(queue, nullptr );
35+ Xcopy<Type>(queue, nullptr );
36+ Xaxpy<Type>(queue, nullptr );
37+ Xdot<Type>(queue, nullptr );
38+ Xnrm2<Type>(queue, nullptr );
39+ Xasum<Type>(queue, nullptr );
40+ Xsum<Type>(queue, nullptr );
41+ Xamax<Type>(queue, nullptr );
42+ Xmax<Type>(queue, nullptr );
43+ Xmin<Type>(queue, nullptr );
44+
45+ // Runs all the level 2 set-up functions that support all precisions
46+ Xgemv<Type>(queue, nullptr );
47+ Xgbmv<Type>(queue, nullptr );
48+ Xtrmv<Type>(queue, nullptr );
49+ Xtbmv<Type>(queue, nullptr );
50+ Xtpmv<Type>(queue, nullptr );
51+
52+ // Runs all the level 3 set-up functions that support all precisions
53+ Xgemm<Type>(queue, nullptr );
54+ Xsymm<Type>(queue, nullptr );
55+ Xsyrk<Type>(queue, nullptr );
56+ Xsyr2k<Type>(queue, nullptr );
57+ Xtrmm<Type>(queue, nullptr );
58+
59+ // Runs all the non-BLAS set-up functions
60+ Xomatcopy<Type>(queue, nullptr );
61+
62+ } catch (const RuntimeErrorCode& e) {
    // Only the two precision-related statuses are tolerated; everything else propagates.
    // NOTE(review): presumably these statuses mean the device lacks support for 'Type',
    // which is why they are not treated as failures -- confirm against the StatusCode definitions.
63+ if (e.status () != StatusCode::kNoDoublePrecision && e.status () != StatusCode::kNoHalfPrecision ) {
    // Bare 'throw' re-raises the exception currently being handled
64+ throw ;
65+ }
66+ }
67+ }
68+
2969template <typename Real, typename Complex>
3070void FillCacheForPrecision (Queue& queue) {
3171 try {
32- // Runs all the level 1 set-up functions
33- Xswap<Real>(queue, nullptr );
34- Xswap<Complex>(queue, nullptr );
35- Xswap<Real>(queue, nullptr );
36- Xswap<Complex>(queue, nullptr );
37- Xscal<Real>(queue, nullptr );
38- Xscal<Complex>(queue, nullptr );
39- Xcopy<Real>(queue, nullptr );
40- Xcopy<Complex>(queue, nullptr );
41- Xaxpy<Real>(queue, nullptr );
42- Xaxpy<Complex>(queue, nullptr );
43- Xdot<Real>(queue, nullptr );
72+ FillCacheForPrecision<Real>(queue);
73+ FillCacheForPrecision<Complex>(queue);
74+
75+ // Runs all the level 1 set-up functions that don't support all precisions
4476 Xdotu<Complex>(queue, nullptr );
4577 Xdotc<Complex>(queue, nullptr );
46- Xnrm2<Real>(queue, nullptr );
47- Xnrm2<Complex>(queue, nullptr );
48- Xasum<Real>(queue, nullptr );
49- Xasum<Complex>(queue, nullptr );
50- Xsum<Real>(queue, nullptr );
51- Xsum<Complex>(queue, nullptr );
52- Xamax<Real>(queue, nullptr );
53- Xamax<Complex>(queue, nullptr );
54- Xmax<Real>(queue, nullptr );
55- Xmax<Complex>(queue, nullptr );
56- Xmin<Real>(queue, nullptr );
57- Xmin<Complex>(queue, nullptr );
58-
59- // Runs all the level 2 set-up functions
60- Xgemv<Real>(queue, nullptr );
61- Xgemv<Complex>(queue, nullptr );
62- Xgbmv<Real>(queue, nullptr );
63- Xgbmv<Complex>(queue, nullptr );
78+
79+ // Runs all the level 2 set-up functions that don't support all precisions
6480 Xhemv<Complex>(queue, nullptr );
6581 Xhbmv<Complex>(queue, nullptr );
6682 Xhpmv<Complex>(queue, nullptr );
@@ -69,10 +85,6 @@ void FillCacheForPrecision(Queue& queue) {
6985 Xspmv<Real>(queue, nullptr );
7086 Xtrmv<Real>(queue, nullptr );
7187 Xtrmv<Complex>(queue, nullptr );
72- Xtbmv<Real>(queue, nullptr );
73- Xtbmv<Complex>(queue, nullptr );
74- Xtpmv<Real>(queue, nullptr );
75- Xtpmv<Complex>(queue, nullptr );
7688 Xger<Real>(queue, nullptr );
7789 Xgeru<Complex>(queue, nullptr );
7890 Xgerc<Complex>(queue, nullptr );
@@ -85,24 +97,11 @@ void FillCacheForPrecision(Queue& queue) {
8597 Xsyr2<Real>(queue, nullptr );
8698 Xspr2<Real>(queue, nullptr );
8799
88- // Runs all the level 3 set-up functions
89- Xgemm<Real>(queue, nullptr );
90- Xgemm<Complex>(queue, nullptr );
91- Xsymm<Real>(queue, nullptr );
92- Xsymm<Complex>(queue, nullptr );
100+ // Runs all the level 3 set-up functions that don't support all precisions
93101 Xhemm<Complex>(queue, nullptr );
94- Xsyrk<Real>(queue, nullptr );
95- Xsyrk<Complex>(queue, nullptr );
96102 Xherk<Complex, Real>(queue, nullptr );
97- Xsyr2k<Real>(queue, nullptr );
98- Xsyr2k<Complex>(queue, nullptr );
99103 Xher2k<Complex, Real>(queue, nullptr );
100- Xtrmm<Real>(queue, nullptr );
101- Xtrmm<Complex>(queue, nullptr );
102-
103- // Runs all the non-BLAS set-up functions
104- Xomatcopy<Real>(queue, nullptr );
105- Xomatcopy<Complex>(queue, nullptr );
104+ Xspr2<Real>(queue, nullptr );
106105
107106 } catch (const RuntimeErrorCode& e) {
108107 if (e.status () != StatusCode::kNoDoublePrecision && e.status () != StatusCode::kNoHalfPrecision ) {
@@ -112,14 +111,14 @@ void FillCacheForPrecision(Queue& queue) {
112111}
113112
114113// Fills the cache with all binaries for a specific device
115- // TODO: Add half-precision FP16 set-up calls
116114StatusCode FillCache (const RawDeviceID device) {
117115 try {
118116 // Creates a sample context and queue to match the normal routine calling conventions
119117 auto device_cpp = Device (device);
120118 auto context = Context (device_cpp);
121119 auto queue = Queue (context, device_cpp);
122120
121+ FillCacheForPrecision<half>(queue);
123122 FillCacheForPrecision<float , float2>(queue);
124123 FillCacheForPrecision<double , double2>(queue);
125124
0 commit comments