Skip to content

Commit db23ace

Browse files
committed
support float and double
1 parent 37870e9 commit db23ace

File tree

8 files changed

+94
-70
lines changed

8 files changed

+94
-70
lines changed

pyfmm/C_extension/Makefile

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,52 @@
1-
21
SRC_DIR = ./src
32
INC_DIR = ./include
43
BUILD_DIR = ./build
4+
BUILD_DIR_FLOAT = $(BUILD_DIR)/float
5+
BUILD_DIR_DOUBLE = $(BUILD_DIR)/double
56
LIB_DIR = ./lib
67
LIB_NAME = $(LIB_DIR)/libfmm
78
LIB_EXT = .so
89

9-
#
1010
CC = gcc
1111
CFLAGS = -O3 -g -ffast-math -march=native -mtune=native -fPIC -Wimplicit-function-declaration -I$(INC_DIR) -lm # -fsanitize=address -lasan
1212
LDFLAGS = -shared -lm # -lasan
13+
1314
SRCS = $(wildcard $(SRC_DIR)/*.c)
1415
INCS = $(wildcard $(INC_DIR)/*.h)
15-
OBJS = $(patsubst $(SRC_DIR)/%.c, $(BUILD_DIR)/%.o, $(SRCS))
16+
17+
# 不同版本的目标文件目录
18+
OBJS_FLOAT = $(patsubst $(SRC_DIR)/%.c, $(BUILD_DIR_FLOAT)/%.o, $(SRCS))
19+
OBJS_DOUBLE = $(patsubst $(SRC_DIR)/%.c, $(BUILD_DIR_DOUBLE)/%.o, $(SRCS))
20+
21+
# 生成的库名称
22+
TARGET_FLOAT = $(LIB_NAME)_float$(LIB_EXT)
23+
TARGET_DOUBLE = $(LIB_NAME)_double$(LIB_EXT)
1624

1725
.PHONY: all clean
1826

19-
TARGET = $(LIB_NAME)$(LIB_EXT)
27+
all: $(TARGET_FLOAT) $(TARGET_DOUBLE)
2028

21-
all: $(TARGET)
29+
# 链接动态库,生成 float 版本
30+
$(TARGET_FLOAT): $(OBJS_FLOAT)
31+
@mkdir -p $(LIB_DIR)
32+
$(CC) -o $@ $^ $(LDFLAGS)
2233

23-
# 链接动态库
24-
# 编译语句中顺序很关键,否则链接不上库
25-
$(TARGET): $(OBJS)
26-
@mkdir -p $(BUILD_DIR)
34+
# 链接动态库,生成 double 版本
35+
$(TARGET_DOUBLE): $(OBJS_DOUBLE)
2736
@mkdir -p $(LIB_DIR)
2837
$(CC) -o $@ $^ $(LDFLAGS)
29-
3038

31-
# 编译目标文件
32-
$(BUILD_DIR)/%.o: $(SRC_DIR)/%.c $(INCS)
33-
@mkdir -p $(BUILD_DIR)
34-
$(CC) -o $@ -c $< $(CFLAGS)
39+
# 编译 float 版本的目标文件
40+
$(BUILD_DIR_FLOAT)/%.o: $(SRC_DIR)/%.c $(INCS)
41+
@mkdir -p $(BUILD_DIR_FLOAT)
42+
$(CC) -o $@ -c $< $(CFLAGS) -DUSE_FLOAT
43+
44+
# 编译 double 版本的目标文件
45+
$(BUILD_DIR_DOUBLE)/%.o: $(SRC_DIR)/%.c $(INCS)
46+
@mkdir -p $(BUILD_DIR_DOUBLE)
47+
$(CC) -o $@ -c $< $(CFLAGS)
3548

3649
# 清理
3750
clean:
38-
rm -rf $(BUILD_DIR)
39-
rm -f $(TARGET)
51+
rm -rf $(BUILD_DIR_FLOAT) $(BUILD_DIR_DOUBLE)
52+
rm -f $(TARGET_FLOAT) $(TARGET_DOUBLE)

pyfmm/C_extension/include/const.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@
1515
#define DEG1 0.017453292519943295 ///< \f$ \frac{\pi}{180} \f$
1616
#define KM1DEG 111.194926644 ///< \f$ R_{e} \times \frac{\pi}{180} \f$
1717

18-
19-
typedef double MYREAL; ///< 单精度 or 双精度
18+
#ifdef USE_FLOAT
19+
typedef float MYREAL; ///< 单精度 or 双精度
20+
#else
21+
typedef double MYREAL;
22+
#endif
125 KB
Binary file not shown.

pyfmm/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from .c_interfaces import load_c_lib
2+
3+
# 默认使用双精度
4+
load_c_lib(use_float=False)
15

26
from . import traveltime
37
from .traveltime import *
48

59
from . import c_interfaces
610

711

8-
from ._version import __version__
12+
from ._version import __version__

pyfmm/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.2.2'
1+
__version__ = '0.2.3'

pyfmm/c_interfaces.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,58 +9,62 @@
99

1010
import os
1111
from ctypes import *
12+
from typing import Any
1213

13-
__all__ = [
14-
'USE_FLOAT',
15-
'NPCT_REAL_TYPE',
16-
'C_FastMarching',
17-
'C_FMM_raytracing'
18-
]
14+
PDOUBLE = POINTER(c_double)
15+
PFLOAT = POINTER(c_float)
16+
PINT = POINTER(c_int)
1917

20-
USE_FLOAT = False
18+
USE_FLOAT:bool = False
2119
"""libfmm库中走时和慢度数组是否使用单精度浮点数"""
20+
NPCT_REAL_TYPE:str = 'f8'
2221

23-
NPCT_REAL_TYPE = 'f4' if USE_FLOAT else 'f8'
22+
C_FastMarching:Any = None
23+
C_FMM_raytracing:Any = None
2424

25-
REAL = c_float if USE_FLOAT else c_double
26-
PREAL = POINTER(REAL)
27-
PDOUBLE= POINTER(c_double)
28-
PINT= POINTER(c_int)
25+
def load_c_lib(use_float:bool=False):
26+
global USE_FLOAT, NPCT_REAL_TYPE, C_FastMarching, C_FMM_raytracing
2927

28+
USE_FLOAT = use_float
29+
NPCT_REAL_TYPE = 'f4' if USE_FLOAT else 'f8'
30+
_suffix = 'float' if USE_FLOAT else 'double'
3031

31-
libfmm = cdll.LoadLibrary(
32-
os.path.join(
33-
os.path.abspath(os.path.dirname(__file__)),
34-
"C_extension/lib/libfmm.so"))
35-
"""libfmm库"""
32+
REAL = c_float if USE_FLOAT else c_double
33+
PREAL = POINTER(REAL)
3634

35+
libfmm = cdll.LoadLibrary(
36+
os.path.join(
37+
os.path.abspath(os.path.dirname(__file__)),
38+
f"C_extension/lib/libfmm_{_suffix}.so"))
39+
"""libfmm库"""
3740

38-
C_FastMarching = libfmm.FastMarching
39-
"""C库中计算走时场的主函数 FastMarching, 详见C API同名函数"""
4041

41-
C_FMM_raytracing = libfmm.FMM_raytracing
42-
"""C库中根据走时场进行射线追踪 FMM_raytracing, 详见C API同名函数"""
42+
C_FastMarching = libfmm.FastMarching
43+
"""C库中计算走时场的主函数 FastMarching, 详见C API同名函数"""
4344

45+
C_FMM_raytracing = libfmm.FMM_raytracing
46+
"""C库中根据走时场进行射线追踪 FMM_raytracing, 详见C API同名函数"""
4447

45-
C_FastMarching.restype = None
46-
C_FastMarching.argtypes = [
47-
PDOUBLE, c_int,
48-
PDOUBLE, c_int,
49-
PDOUBLE, c_int,
50-
c_double, c_double, c_double,
51-
c_int, PREAL,
52-
PREAL, c_bool,
53-
c_int, c_int, c_bool
54-
]
5548

49+
C_FastMarching.restype = None
50+
C_FastMarching.argtypes = [
51+
PDOUBLE, c_int,
52+
PDOUBLE, c_int,
53+
PDOUBLE, c_int,
54+
c_double, c_double, c_double,
55+
c_int, PREAL,
56+
PREAL, c_bool,
57+
c_int, c_int, c_bool
58+
]
5659

57-
C_FMM_raytracing.restype = c_float
58-
C_FMM_raytracing.argtypes = [
59-
PDOUBLE, c_int,
60-
PDOUBLE, c_int,
61-
PDOUBLE, c_int,
62-
c_double, c_double, c_double,
63-
c_double, c_double, c_double, c_double, c_double,
64-
PREAL, c_bool,
65-
PDOUBLE, PINT
66-
]
60+
61+
C_FMM_raytracing.restype = c_float
62+
C_FMM_raytracing.argtypes = [
63+
PDOUBLE, c_int,
64+
PDOUBLE, c_int,
65+
PDOUBLE, c_int,
66+
c_double, c_double, c_double,
67+
c_double, c_double, c_double, c_double, c_double,
68+
PREAL, c_bool,
69+
PDOUBLE, PINT
70+
]

pyfmm/traveltime.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ctypes import byref, c_int
1313
from scipy import interpolate
1414

15-
from .c_interfaces import *
15+
from . import c_interfaces
1616

1717

1818
def travel_time_source(
@@ -56,14 +56,14 @@ def travel_time_source(
5656
c_xarr = npct.as_ctypes(xarr.astype('f8'))
5757
c_yarr = npct.as_ctypes(yarr.astype('f8'))
5858
c_zarr = npct.as_ctypes(zarr.astype('f8'))
59-
slw_ravel = slw.ravel().astype(NPCT_REAL_TYPE)
59+
slw_ravel = slw.ravel().astype(c_interfaces.NPCT_REAL_TYPE)
6060
c_slw = npct.as_ctypes(slw_ravel)
6161

62-
TT = np.zeros_like(slw).astype(NPCT_REAL_TYPE)
62+
TT = np.zeros_like(slw).astype(c_interfaces.NPCT_REAL_TYPE)
6363
TT_ravel = TT.ravel()
6464
c_TT = npct.as_ctypes(TT_ravel)
6565

66-
C_FastMarching(
66+
c_interfaces.C_FastMarching(
6767
c_xarr, len(xarr),
6868
c_yarr, len(yarr),
6969
c_zarr, len(zarr),
@@ -103,13 +103,13 @@ def travel_time_iniTT(
103103
c_xarr = npct.as_ctypes(xarr.astype('f8'))
104104
c_yarr = npct.as_ctypes(yarr.astype('f8'))
105105
c_zarr = npct.as_ctypes(zarr.astype('f8'))
106-
slw_ravel = slw.ravel().astype(NPCT_REAL_TYPE)
106+
slw_ravel = slw.ravel().astype(c_interfaces.NPCT_REAL_TYPE)
107107
c_slw = npct.as_ctypes(slw_ravel)
108108

109-
TT_ravel = iniTT.ravel().astype(NPCT_REAL_TYPE)
109+
TT_ravel = iniTT.ravel().astype(c_interfaces.NPCT_REAL_TYPE)
110110
c_TT = npct.as_ctypes(TT_ravel)
111111

112-
C_FastMarching(
112+
c_interfaces.C_FastMarching(
113113
c_xarr, len(xarr),
114114
c_yarr, len(yarr),
115115
c_zarr, len(zarr),
@@ -144,7 +144,7 @@ def raytracing(
144144
:return: (接收点走时,形状为(ndots, 3)的射线坐标)
145145
'''
146146

147-
TT_ravel = TT.ravel().astype(NPCT_REAL_TYPE)
147+
TT_ravel = TT.ravel().astype(c_interfaces.NPCT_REAL_TYPE)
148148
c_TT = npct.as_ctypes(TT_ravel)
149149

150150
c_xarr = npct.as_ctypes(xarr.astype('f8'))
@@ -158,7 +158,7 @@ def raytracing(
158158
c_rays = npct.as_ctypes(rays)
159159
c_ndots = c_int(maxdots)
160160

161-
travt = C_FMM_raytracing(
161+
travt = c_interfaces.C_FMM_raytracing(
162162
c_xarr, len(xarr),
163163
c_yarr, len(yarr),
164164
c_zarr, len(zarr),

0 commit comments

Comments
 (0)