@@ -194,32 +194,41 @@ void TestBlocks(const thrust::universal_vector<T>& k_cache, // [B, H, S,
194194 }
195195}
196196
197+ double get_memory_bandwidth () // -> GB/s
198+ {
199+ int clock_rate_khz{};
200+ int bus_width_bits{};
201+ cudaDeviceGetAttribute (&clock_rate_khz, cudaDevAttrMemoryClockRate, 0 );
202+ cudaDeviceGetAttribute (&bus_width_bits, cudaDevAttrGlobalMemoryBusWidth, 0 );
203+ return 2 . * (double )clock_rate_khz / 1e6 * (double )bus_width_bits / 8 .;
204+ }
205+
197206#define KV_INT8 0
198207
199- #define KV_INT4 0
208+ #define KV_INT4 1
200209
201- #define DECODING 0
210+ #define DECODING 1
202211
203212template <class T >
204213int test_attention ()
205214{
206215 AttentionParams<T> params{};
207216
208- constexpr size_t kHeadDim = 192 ;
217+ constexpr size_t kHeadDim = 128 ;
209218
210219#if DECODING
211220 // constexpr size_t kHeadNum = 32;
212221 // constexpr size_t kBatchSize = 64;
213222 constexpr size_t kHeadNum = 32 ;
214223 constexpr size_t KvHeadNum = kHeadNum / 4 ;
215- constexpr size_t kBatchSize = 1 ;
224+ constexpr size_t kBatchSize = 128 ;
216225 constexpr size_t kInputLen = 1 ;
217226 // constexpr size_t kSequenceLen = 63;
218227 // constexpr size_t kSequenceLen = 4095;
219228 // constexpr size_t kSequenceLen = 511;
220229 // constexpr size_t kSequenceLen = 2047;
221230 // constexpr size_t kSequenceLen = 4095;
222- constexpr size_t kSequenceLen = 8191 ;
231+ constexpr size_t kSequenceLen = 8 * 1024 - 1 ;
223232 // constexpr size_t kSequenceLen = 32767;
224233 // constexpr size_t kSequenceLen = 65535;
225234 // constexpr size_t kSequenceLen = 131071;
@@ -229,7 +238,7 @@ int test_attention()
229238 // constexpr size_t kSequenceLen = (1 << 22) - 1; // 4M
230239 // constexpr size_t kSequenceLen = (1 << 24) - 1; // 16M
231240 // constexpr int kSequenceLen = 2047;
232- constexpr int kBlockSz = 128 ;
241+ constexpr int kBlockSz = 64 ;
233242 constexpr int kMaxSplitK = 128 ;
234243#else
235244
@@ -430,11 +439,11 @@ int test_attention()
430439 params.qk = qk_buf.data ().get ();
431440 params.pr = pr_buf.data ().get ();
432441
433- Reference<T> reference (kDump ? Reference<T>::kUNFUSED : Reference<T>::kFLASH_ATTENTION , {});
434- // Reference<T> reference(Reference<T>::kUNFUSED, {});
442+ // Reference<T> reference(kDump ? Reference<T>::kUNFUSED : Reference<T>::kFLASH_ATTENTION, {});
443+ Reference<T> reference (Reference<T>::kUNFUSED , {});
435444 reference.Reshape (kInputLen , kContextLen , kHeadNum , kHeadDim , KvHeadNum, kBatchSize );
436445
437- for (int i = 0 ; i < 1 ; ++i) {
446+ for (int i = 0 ; i < 0 ; ++i) {
438447 reference.Execute (params.out , //
439448 k_cache_ref.data ().get (),
440449 v_cache_ref.data ().get (),
@@ -473,8 +482,16 @@ int test_attention()
473482
474483 std::vector<thrust::universal_vector<T>> outputs;
475484
476- for (int i = 0 ; i < std::max (kTestIter , 1 ); ++i) {
485+ std::vector<cudaEvent_t> ev_start (kTestIter );
486+ std::vector<cudaEvent_t> ev_end (kTestIter );
487+
488+ for (int i = 0 ; i < kTestIter ; ++i) {
489+ cudaEventCreate (&ev_start[i]);
490+ cudaEventCreate (&ev_end[i]);
491+ }
477492
493+ for (int i = 0 ; i < std::max (kTestIter , 1 ); ++i) {
494+ cudaEventRecord (ev_start[i]);
478495#if DECODING
479496 dispatchDecoding<T>(params);
480497#else
@@ -487,6 +504,8 @@ int test_attention()
487504 dispatchAttention (params);
488505 // params.linear_iter_params.kv_cache = std::exchange(tmp, nullptr);
489506#endif
507+ cudaEventRecord (ev_end[i]);
508+
490509 if (auto err = cudaGetLastError (); err != cudaSuccess) {
491510 std::cout << cudaGetErrorString (err) << " \n " ;
492511 return -1 ;
@@ -537,6 +556,20 @@ int test_attention()
537556 kQuantPolicy );
538557 cudaDeviceSynchronize ();
539558
559+ const size_t nbytes = blocks.size ();
560+
561+ const float peak_bw = get_memory_bandwidth ();
562+
563+ std::cout << " Device peak global memory bandwidth: " << peak_bw << " GB/s\n " ;
564+
565+ for (int i = 0 ; i < kTestIter ; ++i) {
566+ float ms{};
567+ cudaEventElapsedTime (&ms, ev_start[i], ev_end[i]);
568+ const float bw = nbytes / 1e9f / ms * 1000 .f ;
569+ const float percent = bw / peak_bw * 100 .f ;
570+ printf (" time %.3f ms, bw %.3f GB/s, %.3f %%\n " , ms, bw, percent);
571+ }
572+
540573 if (outputs.size () > 1 ) {
541574 std::cout << " Evaluating consistency..." << std::endl;
542575 for (size_t i = 1 ; i < outputs.size (); ++i) {
0 commit comments