@@ -1869,8 +1869,45 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
18691869 return nullptr ;
18701870}
18711871
1872+ static bool supports_tensor (const struct ggml_tensor * op) {
1873+ if (op->op == GGML_OP_MUL_MAT &&
1874+ op->src [0 ]->buffer &&
1875+ (ggml_n_dims (op->src [0 ]) == 2 ) &&
1876+ op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1877+ ggml_repack_get_optimal_repack_type (op->src [0 ])) {
1878+
1879+ if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1880+ return false ;
1881+ }
1882+
1883+ if (op->src [1 ]->type == GGML_TYPE_F32) {
1884+ return true ;
1885+ }
1886+
1887+ } else if (op->op == GGML_OP_MUL_MAT_ID && op->src [0 ]->buffer &&
1888+ (ggml_n_dims (op->src [0 ]) == 3 ) &&
1889+ op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1890+ ggml_repack_get_optimal_repack_type (op->src [0 ])) {
1891+
1892+ if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1893+ return false ;
1894+ }
1895+
1896+ if (op->src [1 ]->type == GGML_TYPE_F32) {
1897+ return true ;
1898+ }
1899+ }
1900+ return false ;
1901+ }
1902+
18721903static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
1873- tensor->extra = (void *) const_cast <ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type (tensor));
1904+ if (tensor->op == GGML_OP_NONE) {
1905+ tensor->extra = (void *) const_cast <ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type (tensor));
1906+ }
1907+
1908+ if (supports_tensor (tensor)) {
1909+ tensor->src [0 ]->extra = (void *) const_cast <ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type (tensor->src [0 ]));
1910+ }
18741911
18751912 GGML_UNUSED (buffer);
18761913 return GGML_STATUS_SUCCESS;
@@ -1918,39 +1955,7 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
19181955namespace ggml ::cpu::repack {
19191956class extra_buffer_type : ggml::cpu::extra_buffer_type {
// supports_op: after this change the method simply returns supports_tensor(op). The lines marked '-' below are the former inline MUL_MAT / MUL_MAT_ID checks that this diff removes from the method body.
19201957 bool supports_op (ggml_backend_dev_t , const struct ggml_tensor * op) override {
1921- if ( op->op == GGML_OP_MUL_MAT &&
1922- op->src [0 ]->buffer &&
1923- (ggml_n_dims (op->src [0 ]) == 2 ) &&
1924- op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1925- ggml_repack_get_optimal_repack_type (op->src [0 ])
1926- ) {
1927- if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1928- return false ;
1929- }
1930- if (op->src [1 ]->type == GGML_TYPE_F32) {
1931- return true ;
1932- }
1933- // if (op->src[1]->type == GGML_TYPE_Q8_0) {
1934- // return true;
1935- // }
1936- // may be possible if Q8_0 packed...
1937- } else if (op->op == GGML_OP_MUL_MAT_ID
1938- && op->src [0 ]->buffer
1939- && (ggml_n_dims (op->src [0 ]) == 3 )
1940- && op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type ()
1941- && ggml_repack_get_optimal_repack_type (op->src [0 ])
1942- ) {
1943- if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1944- return false ;
1945- }
1946- if (op->src [1 ]->type == GGML_TYPE_F32) {
1947- return true ;
1948- }
1949- // if (op->src[1]->type == GGML_TYPE_Q8_0) {
1950- // return true;
1951- // }
1952- }
1953- return false ;
1958+ return supports_tensor (op);
19541959 }
19551960
19561961 ggml::cpu::tensor_traits * get_tensor_traits (const struct ggml_tensor * op) override {
0 commit comments