You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -795,9 +796,9 @@ template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx
795
796
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
796
797
template void spmm_coo_very_sparse_naive<signedchar, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signedchar *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
797
798
798
-
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
799
-
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
800
-
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
799
+
template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
800
+
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
801
+
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
801
802
802
803
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
803
804
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
Copy file name to clipboardExpand all lines: csrc/ops.cuh
+4-4Lines changed: 4 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -171,16 +171,16 @@ void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, i
171
171
voidstrided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
172
172
longlongint strideA, longlongint strideB, longlongint strideC, int batchCount);
173
173
174
-
template <int DTYPE_OUT, int SCALE_ROWS> intigemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
174
+
template <int DTYPE_OUT, int SCALE_ROWS> intigemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
175
175
176
176
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> voidtransform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
177
177
voidcutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
178
-
voiddequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols);
178
+
voiddequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
179
179
voidgetColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
180
-
voidgetRowStats(half *A, float *rowStats, float threshold, int rows, int cols);
180
+
voidgetRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
Copy file name to clipboardExpand all lines: csrc/pythonInterface.cpp
+18-18Lines changed: 18 additions & 18 deletions
Original file line number
Diff line number
Diff line change
@@ -175,14 +175,14 @@ void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRo
175
175
voidextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
176
176
voidextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
177
177
178
-
intigemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
179
-
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
178
+
intigemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
179
+
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
180
180
}
181
-
intigemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
182
-
return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
181
+
intigemmlt_8(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
182
+
return igemmlt<8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
183
183
}
184
-
intigemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
185
-
return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
184
+
intigemmlt_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
185
+
return igemmlt<8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
186
186
}
187
187
188
188
voidspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
intcigemmlt_32(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
312
-
returnigemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
311
+
intcigemmlt_32(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
312
+
returnigemmlt_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
313
313
}
314
-
intcigemmlt_8(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
315
-
returnigemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
314
+
intcigemmlt_8(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
315
+
returnigemmlt_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
316
316
}
317
-
intcigemmlt_8_rowscale(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) {
318
-
returnigemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc);
317
+
intcigemmlt_8_rowscale(Context *context, int m, int n, int k, constint8_t *A, constint8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
318
+
returnigemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
voidcdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
0 commit comments