|
50 | 50 | ] |
51 | 51 |
|
52 | 52 |
|
53 | | -def _call_syrk(x1, x2): |
54 | | - """ |
55 | | - Check to see if `syrk` can be called instead of `gemm`. |
56 | | -
|
57 | | - It is assumed that x1 and x2 are usm_ndarray objects. These arrays have |
58 | | - already been validated to be 2-dimensional and contiguous. Therefore, this |
59 | | - function only verifies the following: Both arrays reference the same |
60 | | - memory. The number of rows in x1 equals the number of columns in x2. If one |
61 | | - array is C-contiguous, the other must be F-contiguous. |
62 | | -
|
63 | | - """ |
64 | | - call_syrk = False |
65 | | - if ( |
66 | | - x1._pointer == x2._pointer |
67 | | - and x1.shape[0] == x2.shape[1] |
68 | | - and x1.flags.c_contiguous != x2.flags.c_contiguous |
69 | | - and x1.flags.f_contiguous != x2.flags.f_contiguous |
70 | | - ): |
71 | | - call_syrk = True |
72 | | - |
73 | | - return call_syrk |
74 | | - |
75 | | - |
76 | 53 | def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"): |
77 | 54 | """ |
78 | 55 | Determines the output array data type. |
@@ -541,6 +518,29 @@ def _get_signature(func): |
541 | 518 | return signature, distinct_core |
542 | 519 |
|
543 | 520 |
|
| 521 | +def _is_syrk_compatible(x1, x2): |
| 522 | + """ |
| 523 | + Check to see if `syrk` can be called instead of `gemm`. |
| 524 | + Input arrays have already been validated to be 2-dimensional. |
| 525 | +
|
| 526 | + """ |
| 527 | + # Must start at the same memory address (identical data pointers) |
| 528 | + if dpnp.get_usm_ndarray(x1)._pointer != dpnp.get_usm_ndarray(x2)._pointer: |
| 529 | + return False |
| 530 | + |
| 531 | + # Result must be square |
| 532 | + if x1.shape[0] != x2.shape[1]: |
| 533 | + return False |
| 534 | + |
| 535 | + # Strides must match transpose pattern |
| 536 | + x1_strides = x1.strides |
| 537 | + x2_strides = x2.strides |
| 538 | + if x1_strides[0] != x2_strides[1] or x1_strides[1] != x2_strides[0]: |
| 539 | + return False |
| 540 | + |
| 541 | + return True |
| 542 | + |
| 543 | + |
544 | 544 | def _shape_error(shape1, shape2, func, err_msg): |
545 | 545 | """Validate the shapes of input and output arrays.""" |
546 | 546 |
|
@@ -983,6 +983,11 @@ def dpnp_multiplication( |
983 | 983 | x1 = dpnp.reshape(x1, x1_shape[-2:]) |
984 | 984 | x2 = dpnp.reshape(x2, x2_shape[-2:]) |
985 | 985 | res_shape = (x1_shape[-2], x2_shape[-1]) |
| 986 | + if _is_syrk_compatible(x1, x2): |
| 987 | + call_flag = "syrk" |
| 988 | + res_dtype_orig = res_dtype |
| 989 | + if dpnp.issubdtype(res_dtype, dpnp.integer): |
| 990 | + res_dtype = dpnp.default_float_type(x1.device) |
986 | 991 | elif x1_base_is_1D: |
987 | 992 | # TODO: implement gemv_batch to use it here with transpose |
988 | 993 | call_flag = "gemm_batch" |
@@ -1088,21 +1093,17 @@ def dpnp_multiplication( |
1088 | 1093 | depends=_manager.submitted_events, |
1089 | 1094 | ) |
1090 | 1095 | _manager.add_event_pair(ht_ev, gemv_ev) |
| 1096 | + elif call_flag == "syrk": |
| 1097 | + _manager = dpu.SequentialOrderManager[exec_q] |
| 1098 | + ht_ev, gemv_ev = bi._syrk( |
| 1099 | + exec_q, |
| 1100 | + dpnp.get_usm_ndarray(x1), |
| 1101 | + dpnp.get_usm_ndarray(result), |
| 1102 | + depends=_manager.submitted_events, |
| 1103 | + ) |
| 1104 | + _manager.add_event_pair(ht_ev, gemv_ev) |
1091 | 1105 | elif call_flag == "gemm": |
1092 | | - x1_usm = dpnp.get_usm_ndarray(x1) |
1093 | | - x2_usm = dpnp.get_usm_ndarray(x2) |
1094 | | - call_syrk = _call_syrk(x1_usm, x2_usm) |
1095 | | - if call_syrk: |
1096 | | - _manager = dpu.SequentialOrderManager[exec_q] |
1097 | | - ht_ev, gemv_ev = bi._syrk( |
1098 | | - exec_q, |
1099 | | - x1_usm, |
1100 | | - dpnp.get_usm_ndarray(result), |
1101 | | - depends=_manager.submitted_events, |
1102 | | - ) |
1103 | | - _manager.add_event_pair(ht_ev, gemv_ev) |
1104 | | - else: |
1105 | | - result = _gemm_matmul(exec_q, x1_usm, x2_usm, result) |
| 1106 | + result = _gemm_matmul(exec_q, x1, x2, result) |
1106 | 1107 | else: |
1107 | 1108 | assert call_flag == "gemm_batch" |
1108 | 1109 | result = _gemm_batch_matmul(exec_q, x1, x2, result) |
@@ -1130,6 +1131,9 @@ def dpnp_multiplication( |
1130 | 1131 | elif res_shape != result_shape: |
1131 | 1132 | result = dpnp.reshape(result, result_shape) |
1132 | 1133 |
|
| 1134 | + if call_flag == "syrk" and res_dtype_orig != res_dtype: |
| 1135 | + result = result.astype(res_dtype_orig) |
| 1136 | + |
1133 | 1137 | if out is None: |
1134 | 1138 | if axes is not None: |
1135 | 1139 | # Move the data back to the appropriate axes of the result array |
|
0 commit comments