Skip to content

Commit 6551d7e

Browse files
committed
Add "Footgun API"
This adds the `lbt_set_forward()` and `lbt_get_forward()` APIs, which allow directly setting the BLAS API functions that a user might be interested in wrapping. It also adds `lbt_set_default_function` and provides a fairly useless (but at least not fatal) default value for the default function so that if users don't want to seqfault when calling an uninitialized function, they won't.
1 parent 50bde0e commit 6551d7e

File tree

6 files changed

+305
-112
lines changed

6 files changed

+305
-112
lines changed

src/deepbindless_surrogates.c

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,18 @@
22

33
#ifdef LBT_DEEPBINDLESS
44

5-
int find_symbol_idx(const char * name) {
6-
for (int symbol_idx=0; exported_func_names[symbol_idx] != NULL; ++symbol_idx) {
7-
if (strcmp(exported_func_names[symbol_idx], "lsame_") == 0) {
8-
return symbol_idx;
9-
}
10-
}
11-
12-
// This is fatal as it signifies a configuration error in our trampoline symbol list
13-
fprintf(stderr, "Error: Unable to find %s in our symbol list?!\n", name);
14-
exit(1);
15-
}
16-
175
int lsame_idx = -1;
186
const void *old_lsame32 = NULL, *old_lsame64 = NULL;
197
void push_fake_lsame() {
208
// Find `lsame_` in our symbol list (if we haven't done so before)
21-
if (lsame_idx == -1)
9+
if (lsame_idx == -1) {
2210
lsame_idx = find_symbol_idx("lsame_");
11+
if (lsame_idx == -1) {
12+
// This is fatal as it signifies a configuration error in our trampoline symbol list
13+
fprintf(stderr, "Error: Unable to find lsame_ in our symbol list?!\n");
14+
exit(1);
15+
}
16+
}
2317

2418
// Save old values of `lsame_` and `lsame_64_` to our swap location
2519
old_lsame32 = (*exported_func32_addrs[lsame_idx]);

src/libblastrampoline.c

Lines changed: 135 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,138 @@
1010
#define DEEPBINDLESS_INTERFACE_ILP64_LOADED 0x02
1111
uint8_t deepbindless_interfaces_loaded = 0x00;
1212

13+
14+
int32_t find_symbol_idx(const char * name) {
15+
for (int32_t symbol_idx=0; exported_func_names[symbol_idx] != NULL; ++symbol_idx) {
16+
if (strcmp(exported_func_names[symbol_idx], name) == 0) {
17+
return symbol_idx;
18+
}
19+
}
20+
return -1;
21+
}
22+
23+
24+
LBT_DLLEXPORT void lbt_default_func_print_error() {
25+
fprintf(stderr, "Error: no BLAS/LAPACK library loaded!\n");
26+
}
27+
const void * default_func = (const void *)&lbt_default_func_print_error;
28+
LBT_DLLEXPORT const void * lbt_get_default_func() {
29+
return default_func;
30+
}
31+
32+
LBT_DLLEXPORT void lbt_set_default_func(const void * addr) {
33+
default_func = addr;
34+
}
35+
36+
/*
37+
* Force a forward to a particular value.
38+
*/
39+
int32_t set_forward_by_index(int32_t symbol_idx, const void * addr, int32_t interface, int32_t f2c, int32_t verbose) {
40+
// Quit out immediately if this is not a interface setting
41+
if (interface != LBT_INTERFACE_LP64 && interface != LBT_INTERFACE_ILP64) {
42+
return -1;
43+
}
44+
45+
// NULL is a special value that means our "default address"... which may itself be `NULL`!
46+
if (addr == NULL) {
47+
addr = default_func;
48+
}
49+
50+
if (interface == LBT_INTERFACE_LP64) {
51+
(*exported_func32_addrs[symbol_idx]) = addr;
52+
} else {
53+
(*exported_func64_addrs[symbol_idx]) = addr;
54+
55+
// If we're on an RTLD_DEEPBINDless system and our workaround is activated,
56+
// we take over our own 32-bit symbols as well.
57+
if (deepbindless_interfaces_loaded & DEEPBINDLESS_INTERFACE_ILP64_LOADED) {
58+
(*exported_func32_addrs[symbol_idx]) = addr;
59+
}
60+
}
61+
62+
#ifdef F2C_AUTODETECTION
63+
if (f2c == LBT_F2C_REQUIRED) {
64+
// Check to see if this symbol is one of the f2c functions
65+
int f2c_symbol_idx = 0;
66+
for (f2c_symbol_idx=0; f2c_func_idxs[f2c_symbol_idx] != -1; ++f2c_symbol_idx) {
67+
// Jump through the f2c_func_idxs layer of indirection to find the `exported_func*_addrs` offsets
68+
// Skip any symbols that aren't ours
69+
if (f2c_func_idxs[f2c_symbol_idx] != symbol_idx)
70+
continue;
71+
72+
if (verbose) {
73+
char exported_name[MAX_SYMBOL_LEN];
74+
sprintf(exported_name, "%s%s", exported_func_names[symbol_idx], interface == LBT_INTERFACE_ILP64 ? "64_" : "");
75+
printf(" - [%04d] f2c(%s)\n", symbol_idx, exported_name);
76+
}
77+
78+
// Override these addresses with our f2c wrappers
79+
if (interface == LBT_INTERFACE_LP64) {
80+
// Save "true" symbol address in `f2c_$(name)_addr`, then set our exported `$(name)` symbol
81+
// to call `f2c_$(name)`, which will bounce into the true symbol, but fix the return value.
82+
(*f2c_func32_addrs[f2c_symbol_idx]) = (*exported_func32_addrs[symbol_idx]);
83+
(*exported_func32_addrs[symbol_idx]) = f2c_func32_wrappers[f2c_symbol_idx];
84+
} else {
85+
(*f2c_func64_addrs[f2c_symbol_idx]) = (*exported_func64_addrs[symbol_idx]);
86+
(*exported_func64_addrs[symbol_idx]) = f2c_func64_wrappers[f2c_symbol_idx];
87+
}
88+
}
89+
}
90+
#endif // F2C_AUTODETECTION
91+
return 0;
92+
}
93+
94+
LBT_DLLEXPORT const void * lbt_get_forward(const char * symbol_name, int32_t interface, int32_t f2c) {
95+
// Search symbol list for `symbol_name`, then sub off to `set_forward_by_index()`
96+
int32_t symbol_idx = find_symbol_idx(symbol_name);
97+
if (symbol_idx == -1)
98+
return (const void *)-1;
99+
100+
#ifdef F2C_AUTODETECTION
101+
if (f2c == LBT_F2C_REQUIRED) {
102+
// Check to see if this symbol is one of the f2c functions
103+
int f2c_symbol_idx = 0;
104+
for (f2c_symbol_idx=0; f2c_func_idxs[f2c_symbol_idx] != -1; ++f2c_symbol_idx) {
105+
// Skip any symbols that aren't ours
106+
if (f2c_func_idxs[f2c_symbol_idx] != symbol_idx)
107+
continue;
108+
109+
// If we find it, return the "true" address, but only if the currently-exported
110+
// address is actually our f2c wrapper; if it's not then do nothing.
111+
if (interface == LBT_INTERFACE_LP64) {
112+
if (*exported_func32_addrs[symbol_idx] == f2c_func32_wrappers[f2c_symbol_idx]) {
113+
return (const void *)(*f2c_func32_addrs[f2c_symbol_idx]);
114+
}
115+
} else {
116+
if (*exported_func64_addrs[symbol_idx] == f2c_func64_wrappers[f2c_symbol_idx]) {
117+
return (const void *)(*f2c_func64_addrs[f2c_symbol_idx]);
118+
}
119+
}
120+
}
121+
}
122+
#endif
123+
124+
// If we're not in f2c-hell, we can just return our interface's address directly.
125+
if (interface == LBT_INTERFACE_LP64) {
126+
return (const void *)(*exported_func32_addrs[symbol_idx]);
127+
} else {
128+
return (const void *)(*exported_func64_addrs[symbol_idx]);
129+
}
130+
}
131+
132+
LBT_DLLEXPORT int32_t lbt_set_forward(const char * symbol_name, const void * addr, int32_t interface, int32_t f2c, int32_t verbose) {
133+
// Search symbol list for `symbol_name`, then sub off to `set_forward_by_index()`
134+
int32_t symbol_idx = find_symbol_idx(symbol_name);
135+
if (symbol_idx == -1)
136+
return -1;
137+
138+
return set_forward_by_index(symbol_idx, addr, interface, f2c, verbose);
139+
}
140+
13141
/*
14142
* Load `libname`, clearing previous mappings if `clear` is set.
15143
*/
16-
LBT_DLLEXPORT int lbt_forward(const char * libname, int clear, int verbose) {
144+
LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t verbose) {
17145
if (verbose) {
18146
printf("Generating forwards to %s\n", libname);
19147
}
@@ -131,67 +259,25 @@ LBT_DLLEXPORT int lbt_forward(const char * libname, int clear, int verbose) {
131259
}
132260

133261
// Finally, re-export its symbols:
134-
int nforwards = 0;
135-
int symbol_idx = 0;
262+
int32_t nforwards = 0;
263+
int32_t symbol_idx = 0;
136264
char symbol_name[MAX_SYMBOL_LEN];
137265
for (symbol_idx=0; exported_func_names[symbol_idx] != NULL; ++symbol_idx) {
138266
// If `clear` is set, zero out all symbols that may have been set so far
139267
if (clear) {
140-
(*exported_func32_addrs[symbol_idx]) = NULL;
141-
(*exported_func64_addrs[symbol_idx]) = NULL;
268+
(*exported_func32_addrs[symbol_idx]) = default_func;
269+
(*exported_func64_addrs[symbol_idx]) = default_func;
142270
}
143271

144272
// Look up this symbol in the given library, if it is a valid symbol, set it!
145273
sprintf(symbol_name, "%s%s", exported_func_names[symbol_idx], lib_suffix);
146274
void *addr = lookup_symbol(handle, symbol_name);
147275
if (addr != NULL) {
148-
if (verbose) {
149-
char exported_name[MAX_SYMBOL_LEN];
150-
sprintf(exported_name, "%s%s", exported_func_names[symbol_idx], interface == LBT_INTERFACE_ILP64 ? "64_" : "");
151-
printf(" - [%04d] %s -> %s [%p]\n", symbol_idx, exported_name, symbol_name, addr);
152-
}
153-
if (interface == LBT_INTERFACE_LP64) {
154-
(*exported_func32_addrs[symbol_idx]) = addr;
155-
} else {
156-
(*exported_func64_addrs[symbol_idx]) = addr;
157-
158-
// If we're on an RTLD_DEEPBINDless system and our workaround is activated,
159-
// we take over our own 32-bit symbols as well.
160-
if (deepbindless_interfaces_loaded & DEEPBINDLESS_INTERFACE_ILP64_LOADED) {
161-
(*exported_func32_addrs[symbol_idx]) = addr;
162-
}
163-
}
276+
set_forward_by_index(symbol_idx, addr, interface, f2c, verbose);
164277
nforwards++;
165278
}
166279
}
167280

168-
#ifdef F2C_AUTODETECTION
169-
if (f2c == LBT_F2C_REQUIRED) {
170-
int f2c_symbol_idx = 0;
171-
for (f2c_symbol_idx=0; f2c_func_idxs[f2c_symbol_idx] != -1; ++f2c_symbol_idx) {
172-
// Jump through the f2c_func_idxs layer of indirection to find the `exported_func*_addrs` offsets
173-
symbol_idx = f2c_func_idxs[f2c_symbol_idx];
174-
175-
if (verbose) {
176-
char exported_name[MAX_SYMBOL_LEN];
177-
sprintf(exported_name, "%s%s", exported_func_names[symbol_idx], interface == LBT_INTERFACE_ILP64 ? "64_" : "");
178-
printf(" - [%04d] f2c(%s)\n", symbol_idx, exported_name);
179-
}
180-
181-
// Override these addresses with our f2c wrappers
182-
if (interface == LBT_INTERFACE_LP64) {
183-
// Save "true" symbol address in `f2c_$(name)_addr`, then set our exported `$(name)` symbol
184-
// to call `f2c_$(name)`, which will bounce into the true symbol, but fix the return value.
185-
(*f2c_func32_addrs[f2c_symbol_idx]) = (*exported_func32_addrs[symbol_idx]);
186-
(*exported_func32_addrs[symbol_idx]) = f2c_func32_wrappers[f2c_symbol_idx];
187-
} else {
188-
(*f2c_func64_addrs[f2c_symbol_idx]) = (*exported_func64_addrs[symbol_idx]);
189-
(*exported_func64_addrs[symbol_idx]) = f2c_func64_wrappers[f2c_symbol_idx];
190-
}
191-
}
192-
}
193-
#endif
194-
195281
record_library_load(libname, handle, lib_suffix, interface, f2c);
196282
if (verbose) {
197283
printf("Processed %d symbols; forwarded %d symbols with %d-bit interface and mangling to a suffix of \"%s\"\n", symbol_idx, nforwards, interface, lib_suffix);
@@ -200,6 +286,7 @@ LBT_DLLEXPORT int lbt_forward(const char * libname, int clear, int verbose) {
200286
return nforwards;
201287
}
202288

289+
203290
__attribute__((constructor)) void init(void) {
204291
// Initialize config structures
205292
init_config();

src/libblastrampoline.h

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ typedef struct {
9090
*
9191
* If `verbose` is set to a non-zero value, it will print out debugging information.
9292
*/
93-
int lbt_forward(const char * libname, int clear, int verbose);
93+
LBT_DLLEXPORT int32_t lbt_forward(const char * libname, int32_t clear, int32_t verbose);
9494

9595
/*
9696
* Returns a structure describing the currently-loaded libraries as well as the build configuration
9797
* of this `libblastrampoline` instance. See the definition of `lbt_config_t` in this header file
9898
* for more details.
9999
*/
100-
const lbt_config_t * lbt_get_config();
100+
LBT_DLLEXPORT const lbt_config_t * lbt_get_config();
101101

102102
/*
103103
* Returns the number of threads configured by the underlying BLAS library. In the event that
@@ -106,7 +106,7 @@ const lbt_config_t * lbt_get_config();
106106
* for the `lbt_register_thread_interface()` function, although many common functions (such as
107107
* those for `OpenBLAS`, `MKL` and `BLIS`) are already registered by default.
108108
*/
109-
int32_t lbt_get_num_threads();
109+
LBT_DLLEXPORT int32_t lbt_get_num_threads();
110110

111111
/*
112112
* Sets the number of threads in the underlying BLAS library. In the event that multiple
@@ -115,7 +115,7 @@ int32_t lbt_get_num_threads();
115115
* `lbt_register_thread_interface()` function, although many common functions (such as those
116116
* for `OpenBLAS`, `MKL` and `BLIS`) are already registered by default.
117117
*/
118-
void lbt_set_num_threads(int32_t num_threads);
118+
LBT_DLLEXPORT void lbt_set_num_threads(int32_t num_threads);
119119

120120
/*
121121
* Register a new `get_num_threads()`/`set_num_threads()` pair. These functions are assumed to be
@@ -130,7 +130,48 @@ void lbt_set_num_threads(int32_t num_threads);
130130
* name for this functionality, they must register those getter/setter functions here to have them
131131
* automatically called whenever `lbt_{get,set}_num_threads()` is called.
132132
*/
133-
void lbt_register_thread_interface(const char * getter, const char * setter);
133+
LBT_DLLEXPORT void lbt_register_thread_interface(const char * getter, const char * setter);
134+
135+
/*
136+
* Function that simply prints out to `stderr` that someone called an uninitialized function.
137+
* This is the default default function, see `lbt_set_default_func()` for how to override it.
138+
*/
139+
LBT_DLLEXPORT void lbt_default_func_print_error();
140+
141+
/*
142+
* Returns the currently-configured default function that gets called if no mapping has been set
143+
* for an exported symbol. Can return `NULL` if it was set as the default function.
144+
*/
145+
LBT_DLLEXPORT const void * lbt_get_default_func();
146+
147+
/*
148+
* Sets the default function that gets called if no mapping has been set for an exported symbol.
149+
* `NULL` is a valid address, if a segfault upon calling an uninitialized function is desirable.
150+
* Note that this will not be retroactively applied to already-set pointers, so you should call
151+
* this function immediately before calling `lbt_forward()` with `clear` set.
152+
*/
153+
LBT_DLLEXPORT void lbt_set_default_func(const void * addr);
154+
155+
/*
156+
* Returns the currently-configured forward target for the given `symbol_name`, according to the
157+
* requested `interface`. If `f2c` is set to `LBT_F2C_REQUIRED`, then if there is an f2c
158+
* workaround shim in effect for this symbol, this method will thread through that to return the
159+
* "true" symbol address. If `f2c` is set to any other value, then if there is an f2c workaround
160+
* shim in effect, the address of the shim will be returned. (This allows passing this address
161+
* to a 3rd party library which does not want to have to deal with f2c conversion, for instance).
162+
* If this is not an f2c-capable LBT build, `f2c` is ignored completely.
163+
*/
164+
LBT_DLLEXPORT const void * lbt_get_forward(const char * symbol_name, int32_t interface, int32_t f2c);
165+
166+
/*
167+
* Allows directly setting a symbol to be forwarded to a particular address, for the given
168+
* interface. If `f2c` is set to `LBT_F2C_REQUIRED` and this is an f2c-capable LBT build, an
169+
* f2c wrapper function will be interposed between the exported symbol and the targeted address.
170+
* If `verbose` is set to a non-zero value, status messages will be printed out to `stdout`.
171+
* If `addr` is set to `NULL` it will be set as the default function, see `lbt_set_default_func()`
172+
* for how to set the default function pointer.
173+
*/
174+
LBT_DLLEXPORT int32_t lbt_set_forward(const char * symbol_name, const void * addr, int32_t interface, int32_t f2c, int32_t verbose);
134175

135176
#ifdef __cplusplus
136177
} // extern "C"

src/libblastrampoline_internal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ extern const void ** exported_func64_addrs[];
4646
// The config type you get back from lbt_get_config()
4747
#define MAX_TRACKED_LIBS 31
4848

49+
// Functions in `libblastrampoline.c`
50+
int32_t find_symbol_idx(const char * name);
51+
4952
// Functions in `config.c`
5053
void init_config();
5154
void clear_loaded_libraries();

0 commit comments

Comments
 (0)