diff --git a/tritonbench/operators/gdpa/gdpa_blackwell_tlx.py b/tritonbench/operators/gdpa/gdpa_blackwell_tlx.py index 9a44cd029..e738e6de2 100644 --- a/tritonbench/operators/gdpa/gdpa_blackwell_tlx.py +++ b/tritonbench/operators/gdpa/gdpa_blackwell_tlx.py @@ -1217,8 +1217,8 @@ def gdpa_forward_tlx( NUM_SMS = ( get_num_sms() or 1000000 ) * 8 # if num sms is None, use a large number so that it is a no-op - print("NUM_SMS", NUM_SMS) - print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads) + # print("NUM_SMS", NUM_SMS) + # print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads) q = expect_contiguous(query) k = expect_contiguous(key) @@ -1233,7 +1233,7 @@ def gdpa_forward_tlx( H = nheads y_dim = N_CTX_KV * Z x_dim = HEAD_DIM * H // G - USE_ON_DEVICE_TMA = True + USE_ON_DEVICE_TMA = False if not USE_ON_DEVICE_TMA: desc_q = TensorDescriptor( q, @@ -1268,7 +1268,7 @@ def grid_tma_persistent(META): ) activation_enum_int = activation_string_to_int(activation) - print(q.shape, k.shape, v.shape) + # print(q.shape, k.shape, v.shape) # print("activation_enum_int", activation, activation_enum_int) # print(query_offset) # print(key_offset)