
Commit 14bfbf8

make a smarter macro for tensor_data / tensor_set_data to handle both instance and pointer struct member accesses
1 parent 9b8e73f commit 14bfbf8
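
The core idea of this change is a single `tensor_data` / `tensor_set_data` macro that works whether the call site holds a `struct ggml_tensor *` (pointer, `->` access) or a `struct ggml_tensor` instance (`.` access), by dispatching on the argument's static type with C11 `_Generic`. Below is a minimal, self-contained sketch of that dispatch pattern. It is illustrative only: the type and helper names are made up, and the helpers are written as inline functions rather than the helper macros used in the commit, so that every `_Generic` branch type-checks for either kind of argument.

```c
#include <stdio.h>

/* hypothetical stand-in for struct ggml_tensor */
struct tensor_like { void * data; };

/* helpers: one takes a pointer, one takes an instance by value */
static inline void * data_of_ptr(const struct tensor_like * t) { return t->data; }
static inline void * data_of_val(struct tensor_like t)         { return t.data;  }

/* _Generic picks a helper based on the argument's type; only the
   selected helper is actually called with the argument */
#define tensor_like_data(t)                          \
    _Generic((t),                                    \
        struct tensor_like *:       data_of_ptr,     \
        const struct tensor_like *: data_of_ptr,     \
        default:                    data_of_val      \
    )(t)

int main(void) {
    int payload = 42;
    struct tensor_like a = { .data = &payload };
    struct tensor_like * p = &a;

    /* both call forms resolve to the same underlying pointer */
    printf("instance: %p\npointer : %p\n", tensor_like_data(a), tensor_like_data(p));
    return 0;
}
```

The committed version uses helper macros instead of functions (so that the setter can assign through either `.` or `->`), but the dispatch mechanism is the same.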

4 files changed, +61 -38 lines changed

ggml/include/ggml.h

Lines changed: 53 additions & 30 deletions
@@ -648,39 +648,62 @@ extern "C" {
 extern __thread int ggml_current_numa_node;
 #endif
 
-static inline void * tensor_data(const struct ggml_tensor * tensor) {
-#ifdef GGML_NUMA_MIRROR
-    int n = ggml_current_numa_node;
-    if (n == -1)
-        n = 0;
-    return tensor->__data[n];
-#else
-    return tensor->data;
-#endif
-}
+#define tensor_data(tensor) \
+    _Generic((tensor), \
+        struct ggml_tensor*: _tensor_data_ptr(tensor), \
+        const struct ggml_tensor*: _tensor_data_ptr(tensor), \
+        default: _tensor_data_instance(tensor) \
+    )
+
+#define tensor_set_data(tensor, value) \
+    _Generic((tensor), \
+        struct ggml_tensor*: _tensor_set_data_ptr(tensor, value), \
+        default: _tensor_set_data_instance(tensor, value) \
+    )
 
-static inline void tensor_set_data(struct ggml_tensor * tensor, void * data) {
 #ifdef GGML_NUMA_MIRROR
-    if ((uint64_t)data >= \
-        GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-        GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
-        (uint64_t)data < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-        2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) {
-        data = (void*) ((uint64_t)data - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT);
-    }
-    tensor->__data[0] = data;
-    if ((uint64_t)data >= \
-        GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
-        (uint64_t)data < \
-        GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-        GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) {
-        tensor->__data[1] = (void*) ((uint64_t)data + \
-            GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT);
-    } else {
-        tensor->__data[1] = data;
-    }
+#define _tensor_data_ptr(tensor) \
+    (ggml_current_numa_node == -1 ? (tensor)->__data[0] : (tensor)->__data[ggml_current_numa_node])
+
+#define _tensor_data_instance(tensor) \
+    (ggml_current_numa_node == -1 ? (tensor).__data[0] : (tensor).__data[ggml_current_numa_node])
+
+#define _tensor_set_data_ptr(tensor, data_ptr) \
+    do { \
+        void* data_ = (data_ptr); \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + 2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            data_ = (void*)((uint64_t)data_ - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } \
+        (tensor)->__data[0] = data_; \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            (tensor)->__data[1] = (void*)((uint64_t)data_ + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } else { \
+            (tensor)->__data[1] = data_; \
+        } \
+    } while (0)
+
+#define _tensor_set_data_instance(tensor, data_ptr) \
+    do { \
+        void* data_ = (data_ptr); \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + 2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            data_ = (void*)((uint64_t)data_ - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } \
+        (tensor).__data[0] = data_; \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            (tensor).__data[1] = (void*)((uint64_t)data_ + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } else { \
+            (tensor).__data[1] = data_; \
+        } \
+    } while (0)
 #else
-    tensor->data = data;
+#define _tensor_data_ptr(tensor) ((tensor)->data)
+#define _tensor_data_instance(tensor) ((tensor).data)
+#define _tensor_set_data_ptr(tensor, value) ((tensor)->data = (value))
+#define _tensor_set_data_instance(tensor, value) ((tensor).data = (value))
 #endif
-}
 }
 
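For intuition on the `GGML_NUMA_MIRROR` branch above: the setter keeps two entries in `__data[]`, one per virtual-memory window (one window per NUMA node). An address that falls inside the second window is first normalized back into the first window; the node-0 slot then gets the normalized address, and the node-1 slot gets the same address shifted up by one window size (addresses outside the mmap'd region are stored unchanged in both slots). Here is a small worked sketch with made-up constants; the real values come from `GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET` and `GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT`, and the numbers below are purely illustrative.

```c
#include <inttypes.h>
#include <stdio.h>

/* made-up stand-ins for the real GGML constants */
#define BASE_OFFSET    0x100000000ULL   /* hypothetical base of the first window  */
#define NUMA_INCREMENT 0x040000000ULL   /* hypothetical size of one node's window */

int main(void) {
    /* an address that happens to lie inside the second (node-1) window */
    uint64_t addr = BASE_OFFSET + NUMA_INCREMENT + 0x1000;

    /* step 1: normalize into the first window, as _tensor_set_data_ptr does */
    if (addr >= BASE_OFFSET + NUMA_INCREMENT && addr < BASE_OFFSET + 2 * NUMA_INCREMENT) {
        addr -= NUMA_INCREMENT;
    }

    /* step 2: the node-0 slot gets the normalized address ... */
    uint64_t data0 = addr;

    /* step 3: ... and the node-1 slot gets the mirrored address one window up,
       unless the address lies outside the mmap'd region entirely */
    uint64_t data1 = (addr >= BASE_OFFSET && addr < BASE_OFFSET + NUMA_INCREMENT)
                         ? addr + NUMA_INCREMENT
                         : addr;

    printf("__data[0] = 0x%" PRIx64 "\n__data[1] = 0x%" PRIx64 "\n", data0, data1);
    return 0;
}
```

With these illustrative numbers, `__data[0]` ends up at 0x100001000 and `__data[1]` at 0x140001000, i.e. the same payload viewed through the node-0 and node-1 windows.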
ggml/src/ggml-cpu/ops.cpp

Lines changed: 4 additions & 4 deletions
@@ -6861,7 +6861,7 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     src1.nb[1] = k * traits->type_size;
     src1.nb[2] = src1.nb[1];
     src1.nb[3] = src1.nb[2];
-    src1.data = a;
+    tensor_set_data(src1, a);
 
     struct ggml_tensor src0 = {};
     src0.type = type;
@@ -6873,7 +6873,7 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     src0.nb[1] = k * traits->type_size;
     src0.nb[2] = src0.nb[1];
     src0.nb[3] = src0.nb[2];
-    src0.data = b;
+    tensor_set_data(src0, b);
 
     struct ggml_tensor dst = {};
     dst.ne[0] = n;
@@ -6884,7 +6884,7 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     dst.nb[1] = n * sizeof(float);
     dst.nb[2] = dst.nb[1];
     dst.nb[3] = dst.nb[2];
-    dst.data = c;
+    tensor_set_data(dst, c);
     dst.src[0] = &src0;
     dst.src[1] = &src1;
 
@@ -7151,7 +7151,7 @@ static void ggml_compute_forward_conv_2d_dw_cwhn(
         const ggml_conv_2d_dw_params & p) {
 
     const int64_t c = p.channels;
-    const float * knl_data = (const float *)tensor_data(kernel)
+    const float * knl_data = (const float *)tensor_data(kernel);
 
     const int64_t rows_total = p.dst_h * p.batch;
     const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
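
Note that `ggml_call_mul_mat` above works on stack-allocated `ggml_tensor` instances rather than pointers, which is exactly the case the new `default:` (instance) branches of the macros exist to cover; the CUDA changes below exercise the same instance form on temporary slice tensors.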

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 3 deletions
@@ -2164,7 +2164,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     src0_slice.nb[3] = src0_slice.nb[2];
     src0_slice.op = GGML_OP_VIEW;
     src0_slice.view_src = dst->src[0]; // non-const pointer to src0
-    src0_slice.data = (char *) tensor_data(src0) + i02*nb02;
+    tensor_set_data(src0_slice, (char *) tensor_data(src0) + i02*nb02);
 
     ggml_tensor src1_slice;
     memset(&src1_slice, 0, sizeof(src1_slice));
@@ -2178,7 +2178,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0];
     src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1];
     src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2];
-    src1_slice.data = src1_data_cur;
+    tensor_set_data(src1_slice, src1_data_cur);
 
     ggml_tensor dst_slice;
     memset(&dst_slice, 0, sizeof(dst_slice));
@@ -2192,7 +2192,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0];
     dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1];
     dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2];
-    dst_slice.data = dst_data_cur;
+    tensor_set_data(dst_slice, dst_data_cur);
 
     ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
     CUDA_CHECK(cudaGetLastError());

ggml/src/ggml-cuda/gla.cu

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ static __global__ void gated_linear_attn_f32(const int B, const int T, const int
 }
 
 void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const float * k_d = (const float *)tensor_data(dst->src[0])a;
+    const float * k_d = (const float *)tensor_data(dst->src[0]);
     const float * v_d = (const float *)tensor_data(dst->src[1]);
     const float * r_d = (const float *)tensor_data(dst->src[2]);
     const float * td_d = (const float *)tensor_data(dst->src[3]);
