@@ -996,131 +996,6 @@ void explicit_gemm_conv_1D_cpu(
996996 encoder.add_temporaries (std::move (temps));
997997}
998998
999- void explicit_gemm_conv_2D_cpu (
1000- const array& in,
1001- const array& wt,
1002- array out,
1003- const std::vector<int >& padding_lo,
1004- const std::vector<int >& padding_hi,
1005- const std::vector<int >& wt_strides,
1006- const std::vector<int >& wt_dilation,
1007- Stream stream) {
1008- const int N = in.shape (0 ); // Batch size, should be the same as out.shape(0)
1009- const int iH = in.shape (1 ); // Input spatial dim
1010- const int iW = in.shape (2 ); // Input spatial dim
1011- const int oH = out.shape (1 ); // Output spatial dim
1012- const int oW = out.shape (2 ); // Output spatial dim
1013- const int O = wt.shape (0 ); // Out channels
1014- const int C = wt.shape (3 ); // In channels
1015- const int wH = wt.shape (1 ); // Weight spatial dim
1016- const int wW = wt.shape (2 ); // Weight spatial dim
1017-
1018- auto conv_dtype = out.dtype ();
1019- auto & encoder = cpu::get_command_encoder (stream);
1020-
1021- // Pad input
1022- Shape padded_shape = {
1023- N,
1024- iH + padding_lo[0 ] + padding_hi[0 ],
1025- iW + padding_lo[1 ] + padding_hi[1 ],
1026- C};
1027- array in_padded (padded_shape, conv_dtype, nullptr , {});
1028-
1029- // Fill with zeros
1030- std::vector<array> temps;
1031- temps.push_back (array (0 , conv_dtype));
1032- copy_cpu (temps.back (), in_padded, CopyType::Scalar, stream);
1033-
1034- // Pick input slice from padded
1035- size_t data_offset = padding_lo[0 ] * in_padded.strides ()[1 ] +
1036- padding_lo[1 ] * in_padded.strides ()[2 ];
1037- array in_padded_slice (in.shape (), in_padded.dtype (), nullptr , {});
1038- in_padded_slice.copy_shared_buffer (
1039- in_padded,
1040- in_padded.strides (),
1041- in_padded.flags (),
1042- in_padded_slice.size (),
1043- data_offset);
1044- temps.push_back (in_padded_slice);
1045-
1046- // Copy input values into the slice
1047- copy_cpu_inplace (in, in_padded_slice, CopyType::GeneralGeneral, stream);
1048-
1049- // Make strided view
1050- Shape strided_shape = {N, oH, oW, wH, wW, C};
1051-
1052- Strides strided_strides = {
1053- in_padded.strides ()[0 ],
1054- in_padded.strides ()[1 ] * wt_strides[0 ],
1055- in_padded.strides ()[2 ] * wt_strides[1 ],
1056- in_padded.strides ()[1 ],
1057- in_padded.strides ()[2 ],
1058- in_padded.strides ()[3 ]};
1059- auto flags = in_padded.flags ();
1060-
1061- array in_strided_view (strided_shape, in_padded.dtype (), nullptr , {});
1062- in_strided_view.copy_shared_buffer (
1063- in_padded, strided_strides, flags, in_strided_view.size (), 0 );
1064-
1065- // Materialize strided view
1066- Shape strided_reshape = {N * oH * oW, wH * wW * C};
1067- array in_strided (strided_reshape, in_strided_view.dtype (), nullptr , {});
1068- copy_cpu (in_strided_view, in_strided, CopyType::General, stream);
1069- temps.push_back (in_strided);
1070-
1071- // Check wt dtype and prepare
1072- auto gemm_wt = wt;
1073- auto gemm_out = out;
1074-
1075- if (wt.dtype () != float32 || !wt.flags ().row_contiguous ) {
1076- auto ctype =
1077- wt.flags ().row_contiguous ? CopyType::Vector : CopyType::General;
1078- gemm_wt = array (wt.shape (), float32, nullptr , {});
1079- copy_cpu (wt, gemm_wt, ctype, stream);
1080- temps.push_back (gemm_wt);
1081- }
1082-
1083- if (out.dtype () != float32) {
1084- gemm_out = array (out.shape (), float32, nullptr , {});
1085- gemm_out.set_data (allocator::malloc (gemm_out.nbytes ()));
1086- temps.push_back (gemm_out);
1087- }
1088-
1089- encoder.set_input_array (in_strided);
1090- encoder.set_input_array (gemm_wt);
1091- encoder.set_output_array (gemm_out);
1092-
1093- encoder.dispatch ([in_strided_ptr = in_strided.data <float >(),
1094- gemm_wt_ptr = gemm_wt.data <float >(),
1095- gemm_out_ptr = gemm_out.data <float >(),
1096- strided_reshape = std::move (strided_reshape),
1097- O]() {
1098- // Perform gemm
1099- cblas_sgemm (
1100- CblasRowMajor,
1101- CblasNoTrans, // no trans A
1102- CblasTrans, // transB
1103- strided_reshape[0 ], // M
1104- O, // N
1105- strided_reshape[1 ], // K
1106- 1 .0f , // alpha
1107- in_strided_ptr,
1108- strided_reshape[1 ], // lda
1109- gemm_wt_ptr,
1110- strided_reshape[1 ], // ldb
1111- 0 .0f , // beta
1112- gemm_out_ptr,
1113- O // ldc
1114- );
1115- });
1116-
1117- // Copy results if needed
1118- if (out.dtype () != float32) {
1119- copy_cpu_inplace (gemm_out, out, CopyType::Vector, stream);
1120- }
1121- encoder.add_temporaries (std::move (temps));
1122- }
1123-
1124999void explicit_gemm_conv_ND_cpu (
11251000 const array& in,
11261001 const array& wt,
0 commit comments