@@ -1872,9 +1872,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
18721872static bool supports_tensor (const struct ggml_tensor * op) {
18731873 if (op->op == GGML_OP_MUL_MAT &&
18741874 op->src [0 ]->buffer &&
1875- (ggml_n_dims (op->src [0 ]) == 2 ) &&
1876- op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1877- ggml_repack_get_optimal_repack_type (op->src [0 ])) {
1875+ (ggml_n_dims (op->src [0 ]) == 2 ) && ggml_repack_get_optimal_repack_type (op->src [0 ])) {
18781876
18791877 if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
18801878 return false ;
@@ -1885,9 +1883,7 @@ static bool supports_tensor(const struct ggml_tensor * op) {
18851883 }
18861884
18871885 } else if (op->op == GGML_OP_MUL_MAT_ID && op->src [0 ]->buffer &&
1888- (ggml_n_dims (op->src [0 ]) == 3 ) &&
1889- op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1890- ggml_repack_get_optimal_repack_type (op->src [0 ])) {
1886+ (ggml_n_dims (op->src [0 ]) == 3 ) && ggml_repack_get_optimal_repack_type (op->src [0 ])) {
18911887
18921888 if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
18931889 return false ;
@@ -1903,10 +1899,12 @@ static bool supports_tensor(const struct ggml_tensor * op) {
19031899static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
19041900 if (tensor->op == GGML_OP_NONE) {
19051901 tensor->extra = (void *) const_cast <ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type (tensor));
1902+ tensor->buffer = buffer;
19061903 }
19071904
19081905 if (supports_tensor (tensor)) {
19091906 tensor->src [0 ]->extra = (void *) const_cast <ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type (tensor->src [0 ]));
1907+ tensor->src [0 ]->buffer = buffer;
19101908 }
19111909
19121910 GGML_UNUSED (buffer);
@@ -1955,7 +1953,39 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
19551953namespace ggml ::cpu::repack {
19561954class extra_buffer_type : ggml::cpu::extra_buffer_type {
19571955 bool supports_op (ggml_backend_dev_t , const struct ggml_tensor * op) override {
1958- return supports_tensor (op);
1956+ if ( op->op == GGML_OP_MUL_MAT &&
1957+ op->src [0 ]->buffer &&
1958+ (ggml_n_dims (op->src [0 ]) == 2 ) &&
1959+ op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type () &&
1960+ ggml_repack_get_optimal_repack_type (op->src [0 ])
1961+ ) {
1962+ if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1963+ return false ;
1964+ }
1965+ if (op->src [1 ]->type == GGML_TYPE_F32) {
1966+ return true ;
1967+ }
1968+ // if (op->src[1]->type == GGML_TYPE_Q8_0) {
1969+ // return true;
1970+ // }
1971+ // may be possible if Q8_0 packed...
1972+ } else if (op->op == GGML_OP_MUL_MAT_ID
1973+ && op->src [0 ]->buffer
1974+ && (ggml_n_dims (op->src [0 ]) == 3 )
1975+ && op->src [0 ]->buffer ->buft == ggml_backend_cpu_repack_buffer_type ()
1976+ && ggml_repack_get_optimal_repack_type (op->src [0 ])
1977+ ) {
1978+ if (op->src [1 ]->buffer && !ggml_backend_buft_is_host (op->src [1 ]->buffer ->buft )) {
1979+ return false ;
1980+ }
1981+ if (op->src [1 ]->type == GGML_TYPE_F32) {
1982+ return true ;
1983+ }
1984+ // if (op->src[1]->type == GGML_TYPE_Q8_0) {
1985+ // return true;
1986+ // }
1987+ }
1988+ return false ;
19591989 }
19601990
19611991 ggml::cpu::tensor_traits * get_tensor_traits (const struct ggml_tensor * op) override {
0 commit comments