|
58 | 58 |
|
59 | 59 | namespace at::native { |
60 | 60 |
|
61 | | -<<<<<<< HEAD |
62 | | -namespace { |
63 | | - |
64 | | -// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 |
65 | | -c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { |
66 | | - if (resolve_conj && tensor.is_conj()) { |
67 | | - return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj()); |
68 | | - } else { |
69 | | - return c10::MaybeOwned<Tensor>::borrowed(tensor); |
70 | | - } |
71 | | -} |
72 | | - |
73 | | -c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { |
74 | | - if (tensor.is_non_overlapping_and_dense()) { // common case |
75 | | - transpose_tensor = tensor.is_contiguous(); |
76 | | - return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); |
77 | | - } |
78 | | - IntArrayRef tensor_strides = tensor.strides(); |
79 | | - IntArrayRef tensor_sizes = tensor.sizes(); |
80 | | - if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { |
81 | | - transpose_tensor = false; |
82 | | - return resolve_conj_if_indicated(tensor, !transpose_result); |
83 | | - } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { |
84 | | - transpose_tensor = true; |
85 | | - return resolve_conj_if_indicated(tensor, transpose_result); |
86 | | - } else { |
87 | | - transpose_tensor = true; |
88 | | - return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); |
89 | | - } |
90 | | -} |
91 | | - |
92 | | -c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { |
93 | | - if (tensor.is_non_overlapping_and_dense()) { // common case |
94 | | - transpose_tensor = tensor.is_contiguous(); |
95 | | - return resolve_conj_if_indicated(tensor, true); |
96 | | - } |
97 | | - |
98 | | - IntArrayRef tensor_strides = tensor.strides(); |
99 | | - IntArrayRef tensor_sizes = tensor.sizes(); |
100 | | - if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { |
101 | | - transpose_tensor = false; |
102 | | - return resolve_conj_if_indicated(tensor, true); |
103 | | - } else if ((tensor_strides[1] == 1) && |
104 | | - (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { |
105 | | - transpose_tensor = true; |
106 | | - return resolve_conj_if_indicated(tensor, true); |
107 | | - } else { |
108 | | - transpose_tensor = true; |
109 | | - return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); |
110 | | - } |
111 | | -} |
112 | | - |
113 | | -using at::cuda::blas::ScalingType; |
114 | | - |
115 | | -/** |
116 | | - * @brief Prepares matrices for a cuBLAS operation
117 | | - *
118 | | - * This constructor prepares tensors for cuBLAS.
119 | | - * The main difference is that PyTorch uses row-major layout by default while
120 | | - * cuBLAS expects column-major.
121 | | - *
122 | | - * @details
123 | | - * To enable row-major output while using cuBLAS,
124 | | - * we use the mathematical identity that (A × B)^T = B^T × A^T.
125 | | - *
126 | | - * Transpose in this context refers to cuBLAS's (Fortran) definition of transpose:
127 | | - * T = row-major, N = col-major
128 | | - *
129 | | - * Example:
130 | | - * For matrices A (M×K)(row-major) and B (K×N)(row-major):
131 | | - * - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major)
132 | | - * - Using our transpose trick: (B^T × A^T) = (N×K)(T) × (K×M)(T) = N×M(N)
133 | | - *   However, since the output from cuBLAS is column-major, this is
134 | | - *   equivalent to an M×N row-major output, as expected.
135 | | - *
136 | | - * The transpose flags are derived from the layouts of the passed-in tensors.
137 | | - * |
138 | | - * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted |
139 | | - * to their unpacked values to match what cuBLAS expects. |
140 | | - * |
141 | | - * @param mat1 First input matrix |
142 | | - * @param mat2 Second input matrix |
143 | | - * @param c Output matrix (result) |
144 | | - * @param scale_a Optional scaling factor for first matrix |
145 | | - * @param scale_b Optional scaling factor for second matrix |
146 | | - * @param scale_result Optional scaling factor for result |
147 | | - */ |
148 | | -struct cublasCommonArgs { |
149 | | - cublasCommonArgs( |
150 | | - const Tensor& mat1, |
151 | | - const Tensor& mat2, |
152 | | - Tensor& c, |
153 | | - const std::optional<Tensor>& scale_a = std::nullopt, |
154 | | - const std::optional<Tensor>& scale_b = std::nullopt, |
155 | | - const std::optional<Tensor>& scale_result = std::nullopt, |
156 | | - const std::optional<ScalingType>& scaling_choice_a = std::nullopt, |
157 | | - const std::optional<ScalingType>& scaling_choice_b = std::nullopt) { |
158 | | - bool transpose_result = false, transpose_a = false, transpose_b = false; |
159 | | - result = prepare_matrix_for_cublas(c, transpose_result); |
160 | | - mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); |
161 | | - matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); |
162 | | - |
163 | | - // Handle scale tensors if provided |
164 | | - if (scale_a && scale_b) { |
165 | | -      // Since we return a row-major result, the gemm runs as B.T @ A.T;
166 | | -      // check transpose_result to determine whether to swap the scales
167 | | - scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); |
168 | | - scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); |
169 | | - scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; |
170 | | - scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); |
171 | | - scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); |
172 | | - scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; |
173 | | - } |
174 | | - |
175 | | - if (scale_result) { |
176 | | - scale_result_ptr = scale_result->data_ptr(); |
177 | | - scale_result_dtype = scale_result->scalar_type(); |
178 | | - } |
179 | | - |
180 | | - // Update transpose flags |
181 | | - if (transpose_result) { |
182 | | - transpose_a = !transpose_a; |
183 | | - transpose_b = !transpose_b; |
184 | | - } |
185 | | - |
186 | | - auto sizes_a = mata->sizes(); |
187 | | - auto sizes_b = matb->sizes(); |
188 | | - |
189 | | - m = sizes_a[transpose_result ? 1 : 0]; |
190 | | - k = sizes_a[transpose_result ? 0 : 1]; |
191 | | - n = sizes_b[transpose_result ? 0 : 1]; |
192 | | - lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); |
193 | | - ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); |
194 | | - result_ld = result->stride(transpose_result ? 0 : 1); |
195 | | - transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; |
196 | | - transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; |
197 | | - |
198 | | - // cuBLAS expects unpacked values of `k`, `lda` and `ldb`, adjust for 4x2 packing |
199 | | - // if the gemm operands are in packed float4 |
200 | | - if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) { |
201 | | - k = k * 2; |
202 | | - lda = lda * 2; |
203 | | - ldb = ldb * 2; |
204 | | - } |
205 | | - } |
206 | | - |
207 | | - // Matrix members |
208 | | - char transa, transb; |
209 | | - int64_t m, n, k; |
210 | | - int64_t lda, ldb, result_ld; |
211 | | - c10::MaybeOwned<Tensor> mata, matb, result; |
212 | | - |
213 | | - // Scale members |
214 | | - void* scale_mata_ptr = nullptr; |
215 | | - void* scale_matb_ptr = nullptr; |
216 | | - void* scale_result_ptr = nullptr; |
217 | | - std::optional<c10::ScalarType> scale_mata_dtype; |
218 | | - std::optional<ScalingType> scaling_mata_type; |
219 | | - std::optional<c10::ScalarType> scale_matb_dtype; |
220 | | - std::optional<ScalingType> scaling_matb_type; |
221 | | - std::optional<c10::ScalarType> scale_result_dtype; |
222 | | -}; |
223 | | -} // namespace |
224 | | -======= |
225 | 61 | using at::blas::ScalingType; |
226 | 62 | using at::blas::SwizzleType; |
227 | | ->>>>>>> upstream/main |
228 | 63 |
|
229 | 64 | c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) { |
230 | 65 | IntArrayRef tensor_strides = tensor.strides(); |
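
For context on the transpose trick the removed doc comment describes, here is a minimal, self-contained sketch (plain C++ loops standing in for cuBLAS; all names are illustrative, not PyTorch APIs). A row-major (r×c) buffer is bit-identical to a column-major (c×r) buffer holding the transpose, so feeding the operands to a column-major GEMM in swapped order yields the row-major product directly:

```cpp
#include <cstdio>
#include <vector>

// Naive column-major GEMM: C (m x n) = A (m x k) * B (k x n),
// with BLAS-style leading dimensions lda/ldb/ldc.
void gemm_colmajor(int m, int n, int k,
                   const float* A, int lda,
                   const float* B, int ldb,
                   float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p)
        acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B (2x2).
  const int M = 2, K = 3, N = 2;
  std::vector<float> A = {1, 2, 3,
                          4, 5, 6};   // 2x3, row-major
  std::vector<float> B = {7,  8,
                          9,  10,
                          11, 12};    // 3x2, row-major
  std::vector<float> C(M * N);

  // B's buffer, read column-major, is B^T (N x K); A's is A^T (K x M).
  // Computing B^T * A^T column-major gives an N x M col-major result,
  // whose buffer is exactly the M x N row-major A * B.
  gemm_colmajor(N, M, K, B.data(), N, A.data(), K, C.data(), N);

  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 / 139 154
}
```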
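
The removed helpers lean on `c10::MaybeOwned<Tensor>` to skip cloning whenever an operand's layout is already cuBLAS-compatible. Below is a rough, self-contained sketch of that borrow-or-own pattern; it is a simplified stand-in, not the real `c10::MaybeOwned`, which avoids the heap allocation and has a richer API:

```cpp
#include <cstdio>
#include <memory>
#include <string>
#include <utility>

// Simplified borrow-or-own wrapper (illustrative only).
template <typename T>
class MaybeOwnedSketch {
  const T* borrowed_ = nullptr;   // set when we merely reference the input
  std::unique_ptr<T> owned_;      // set when we had to materialize a copy
 public:
  static MaybeOwnedSketch borrowed(const T& t) {
    MaybeOwnedSketch m;
    m.borrowed_ = &t;
    return m;
  }
  static MaybeOwnedSketch owned(T&& t) {
    MaybeOwnedSketch m;
    m.owned_ = std::make_unique<T>(std::move(t));
    return m;
  }
  const T& operator*() const { return owned_ ? *owned_ : *borrowed_; }
  const T* operator->() const { return &**this; }
};

// Mirrors the shape of prepare_matrix_for_cublas: borrow when the input
// is already acceptable, copy only when forced (hypothetical predicate).
MaybeOwnedSketch<std::string> prepare(const std::string& s, bool needs_copy) {
  if (!needs_copy) {
    return MaybeOwnedSketch<std::string>::borrowed(s);  // no allocation
  }
  return MaybeOwnedSketch<std::string>::owned(std::string(s));
}

int main() {
  std::string s = "already-contiguous";
  auto a = prepare(s, /*needs_copy=*/false);  // points at s
  auto b = prepare(s, /*needs_copy=*/true);   // owns its own copy
  std::printf("%s %s\n", a->c_str(), b->c_str());
}
```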
|