@@ -648,39 +648,61 @@ extern "C" {
 extern __thread int ggml_current_numa_node;
 #endif
 
-static inline void * tensor_data(const struct ggml_tensor * tensor) {
-#ifdef GGML_NUMA_MIRROR
-    int n = ggml_current_numa_node;
-    if (n == -1)
-        n = 0;
-    return tensor->__data[n];
-#else
-    return tensor->data;
-#endif
-}
+#define tensor_data(tensor) \
+    _Generic((tensor), \
+        struct ggml_tensor*: _tensor_data_ptr(tensor), \
+        const struct ggml_tensor*: _tensor_data_ptr(tensor), \
+        default: _tensor_data_instance(tensor) \
+    )
+
+#define tensor_set_data(tensor, value) \
+    _Generic((tensor), \
+        struct ggml_tensor*: _tensor_set_data_ptr(tensor, value), \
+        default: _tensor_set_data_instance(tensor, value) \
+    )
 
-static inline void tensor_set_data(struct ggml_tensor * tensor, void * data) {
 #ifdef GGML_NUMA_MIRROR
-    if ((uint64_t)data >= \
-            GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-            GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
-        (uint64_t)data < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-            2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) {
-        data = (void *) ((uint64_t)data - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT);
-    }
-    tensor->__data[0] = data;
-    if ((uint64_t)data >= \
-            GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
-        (uint64_t)data < \
-            GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + \
-            GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) {
-        tensor->__data[1] = (void *) ((uint64_t)data + \
-            GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT);
-    } else {
-        tensor->__data[1] = data;
-    }
+#define _tensor_data_ptr(tensor) \
+    (ggml_current_numa_node == -1 ? (tensor)->__data[0] : (tensor)->__data[ggml_current_numa_node])
+
+#define _tensor_data_instance(tensor) \
+    (ggml_current_numa_node == -1 ? (tensor).__data[0] : (tensor).__data[ggml_current_numa_node])
+
+#define _tensor_set_data_ptr(tensor, data_ptr) \
+    do { \
+        void * data_ = (data_ptr); \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + 2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            data_ = (void *)((uint64_t)data_ - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } \
+        (tensor)->__data[0] = data_; \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            (tensor)->__data[1] = (void *)((uint64_t)data_ + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } else { \
+            (tensor)->__data[1] = data_; \
+        } \
+    } while (0)
+
+#define _tensor_set_data_instance(tensor, data_ptr) \
+    do { \
+        void * data_ = (data_ptr); \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + 2 * GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            data_ = (void *)((uint64_t)data_ - GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } \
+        (tensor).__data[0] = data_; \
+        if ((uint64_t)data_ >= GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET && \
+            (uint64_t)data_ < GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT) { \
+            (tensor).__data[1] = (void *)((uint64_t)data_ + GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT); \
+        } else { \
+            (tensor).__data[1] = data_; \
+        } \
+    } while (0)
 #else
-    tensor->data = data;
+#define _tensor_data_ptr(tensor) ((tensor)->data)
+#define _tensor_data_instance(tensor) ((tensor).data)
+#define _tensor_set_data_ptr(tensor, value) ((tensor)->data = (value))
+#define _tensor_set_data_instance(tensor, value) ((tensor).data = (value))
 #endif
-}
 
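For readers following the change: the new tensor_data / tensor_set_data macros use C11 _Generic to pick an accessor based on whether the argument is a struct ggml_tensor * (or const pointer) or a plain struct ggml_tensor instance, so both kinds of call sites keep working. Below is a minimal, self-contained sketch of the same dispatch idea; the point_t type and point_x_* helpers are illustrative placeholders, not part of ggml, and the sketch uses the function-designator form of _Generic, where each association is a bare identifier so every branch stays well-formed for any argument type:

#include <stdio.h>

typedef struct { int x, y; } point_t;

static int point_x_ptr(const point_t * p) { return p->x; }  // pointer accessor
static int point_x_val(point_t p)         { return p.x;  }  // by-value accessor

// Dispatch on the static type of the argument; only the selected
// function is actually called.
#define point_x(p) \
    _Generic((p), \
        point_t *:       point_x_ptr, \
        const point_t *: point_x_ptr, \
        default:         point_x_val \
    )(p)

int main(void) {
    point_t v = { 1, 2 };
    point_t * p = &v;
    printf("%d %d\n", point_x(p), point_x(v));  // prints: 1 1
    return 0;
}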
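The _tensor_set_data_* helpers keep two views of the same buffer: __data[0] holds the address inside the first NUMA node's mmap window, and __data[1] holds the mirror one GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT higher; addresses outside the reserved window are stored unmirrored in both slots. A small worked sketch of that address arithmetic follows; the BASE_OFFSET / NUMA_INCREMENT values are placeholders chosen only to make the example run, not the real GGML_MMAP_VIRTUAL_MEMORY_* constants:

#include <stdint.h>
#include <stdio.h>

#define BASE_OFFSET    0x100000000ull  // stand-in for GGML_MMAP_VIRTUAL_MEMORY_BASE_OFFSET
#define NUMA_INCREMENT 0x040000000ull  // stand-in for GGML_MMAP_VIRTUAL_MEMORY_NUMA_INCREMENT

int main(void) {
    // Suppose the caller hands us an address in node 1's window.
    uint64_t data = BASE_OFFSET + NUMA_INCREMENT + 0x1234;

    // Normalize into node 0's window first, as the setter does.
    if (data >= BASE_OFFSET + NUMA_INCREMENT &&
        data <  BASE_OFFSET + 2 * NUMA_INCREMENT) {
        data -= NUMA_INCREMENT;
    }

    uint64_t data0 = data;  // __data[0]: node 0 replica
    uint64_t data1 = (data >= BASE_OFFSET && data < BASE_OFFSET + NUMA_INCREMENT)
                         ? data + NUMA_INCREMENT  // __data[1]: node 1 mirror
                         : data;                  // outside the window: no mirror

    printf("node0 = %#llx, node1 = %#llx\n",
           (unsigned long long) data0, (unsigned long long) data1);
    return 0;  // prints: node0 = 0x100001234, node1 = 0x140001234
}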