Skip to content

Commit 5787cfb

Browse files
committed
Make tracking/checking of context more obvious
When C++ code is called from Lua we need to know in which context we are, because calling some functions is only allowed in certain contexts.
1 parent 38c7325 commit 5787cfb

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

src/output-flex.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ static char const osm2pgsql_table_name[] = "osm2pgsql.table";
6969
static char const osm2pgsql_object_metatable[] = "osm2pgsql.object_metatable";
7070

7171
prepared_lua_function_t::prepared_lua_function_t(lua_State *lua_state,
72-
char const *name,
73-
int nresults)
72+
calling_context context,
73+
char const *name, int nresults)
7474
{
7575
int const index = lua_gettop(lua_state);
7676

@@ -80,6 +80,7 @@ prepared_lua_function_t::prepared_lua_function_t(lua_State *lua_state,
8080
m_index = index;
8181
m_name = name;
8282
m_nresults = nresults;
83+
m_calling_context = context;
8384
return;
8485
}
8586

@@ -540,6 +541,12 @@ std::size_t output_flex_t::get_way_nodes()
540541

541542
int output_flex_t::app_get_bbox()
542543
{
544+
if (m_calling_context != calling_context::process_node &&
545+
m_calling_context != calling_context::process_way) {
546+
throw std::runtime_error{"The function get_bbox() can only be called"
547+
" from process_node() or process_way()"};
548+
}
549+
543550
if (lua_gettop(lua_state()) > 1) {
544551
throw std::runtime_error{"No parameter(s) needed for get_box()"};
545552
}
@@ -726,9 +733,9 @@ void output_flex_t::setup_flex_table_columns(flex_table_t *table)
726733

727734
int output_flex_t::app_define_table()
728735
{
729-
if (m_context_node || m_context_way || m_context_relation) {
730-
throw std::runtime_error{"Tables have to be defined before calling any "
731-
"of the process callbacks"};
736+
if (m_calling_context != calling_context::main) {
737+
throw std::runtime_error{"Database tables have to be defined in the"
738+
" main Lua code, not in any of the callbacks"};
732739
}
733740

734741
luaL_checktype(lua_state(), 1, LUA_TTABLE);
@@ -791,8 +798,17 @@ int output_flex_t::table_tostring()
791798

792799
int output_flex_t::table_add_row()
793800
{
801+
if (m_calling_context != calling_context::process_node &&
802+
m_calling_context != calling_context::process_way &&
803+
m_calling_context != calling_context::process_relation) {
804+
throw std::runtime_error{
805+
"The function add_row() can only be called from the "
806+
"process_node/way/relation() functions"};
807+
}
808+
794809
if (lua_gettop(lua_state()) != 2) {
795-
throw std::runtime_error{"Need two parameters: The osm2pgsql.table and the row data"};
810+
throw std::runtime_error{
811+
"Need two parameters: The osm2pgsql.table and the row data"};
796812
}
797813

798814
auto &table_connection =
@@ -823,9 +839,6 @@ int output_flex_t::table_add_row()
823839
"Trying to add relation to table '{}'"_format(table.name())};
824840
}
825841
add_row(&table_connection, *m_context_relation);
826-
} else {
827-
throw std::runtime_error{"The add_row() function can only be called "
828-
"from inside a process function"};
829842
}
830843

831844
return 0;
@@ -1009,6 +1022,8 @@ void output_flex_t::call_lua_function(prepared_lua_function_t func,
10091022
{
10101023
std::lock_guard<std::mutex> guard{lua_mutex};
10111024

1025+
m_calling_context = func.context();
1026+
10121027
lua_pushvalue(lua_state(), func.index()); // the function to call
10131028
push_osm_object_to_lua_stack(
10141029
lua_state(), object,
@@ -1020,6 +1035,8 @@ void output_flex_t::call_lua_function(prepared_lua_function_t func,
10201035
"Failed to execute Lua function 'osm2pgsql.{}':"
10211036
" {}"_format(func.name(), lua_tostring(lua_state(), -1))};
10221037
}
1038+
1039+
m_calling_context = calling_context::main;
10231040
}
10241041

10251042
void output_flex_t::pending_way(osmid_t id)
@@ -1318,10 +1335,12 @@ void output_flex_t::init_lua(std::string const &filename)
13181335
// Check whether the process_* functions are available and store them on
13191336
// the Lua stack for fast access later
13201337
lua_getglobal(lua_state(), "osm2pgsql");
1321-
m_process_node = prepared_lua_function_t{lua_state(), "process_node"};
1322-
m_process_way = prepared_lua_function_t{lua_state(), "process_way"};
1323-
m_process_relation =
1324-
prepared_lua_function_t{lua_state(), "process_relation"};
1338+
m_process_node = prepared_lua_function_t{
1339+
lua_state(), calling_context::process_node, "process_node"};
1340+
m_process_way = prepared_lua_function_t{
1341+
lua_state(), calling_context::process_way, "process_way"};
1342+
m_process_relation = prepared_lua_function_t{
1343+
lua_state(), calling_context::process_relation, "process_relation"};
13251344

13261345
lua_remove(lua_state(), 1); // global "osm2pgsql"
13271346
}

src/output-flex.hpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ extern "C"
2828

2929
using idset_t = osmium::index::IdSetSmall<osmid_t>;
3030

31+
/**
32+
* When C++ code is called from the Lua code we sometimes need to know
33+
* in what context this happens. These are the possible contexts.
34+
*/
35+
enum class calling_context
36+
{
37+
main = 0, ///< In main context, i.e. the Lua script outside any callbacks
38+
process_node = 1, ///< In the process_node() callback
39+
process_way = 2, ///< In the process_way() callback
40+
process_relation = 3 ///< In the process_relation() callback
41+
};
42+
3143
/**
3244
* The flex output calls several user-defined Lua functions. They are
3345
* "prepared" by putting the function pointers on the Lua stack. Objects
@@ -48,8 +60,8 @@ class prepared_lua_function_t
4860
* \param name Name of the function.
4961
* \param nresults The number of results this function is supposed to have.
5062
*/
51-
prepared_lua_function_t(lua_State *lua_state, const char *name,
52-
int nresults = 0);
63+
prepared_lua_function_t(lua_State *lua_state, calling_context context,
64+
const char *name, int nresults = 0);
5365

5466
/// Return the index of the function on the Lua stack.
5567
int index() const noexcept { return m_index; }
@@ -60,13 +72,16 @@ class prepared_lua_function_t
6072
/// The number of results this function is expected to have.
6173
int nresults() const noexcept { return m_nresults; }
6274

75+
calling_context context() const noexcept { return m_calling_context; }
76+
6377
/// Is this function defined in the users Lua code?
6478
operator bool() const noexcept { return m_index != 0; }
6579

6680
private:
6781
char const *m_name = nullptr;
6882
int m_index = 0;
6983
int m_nresults = 0;
84+
calling_context m_calling_context = calling_context::main;
7085
};
7186

7287
class output_flex_t : public output_t
@@ -195,10 +210,13 @@ class output_flex_t : public output_t
195210

196211
std::size_t m_num_way_nodes = std::numeric_limits<std::size_t>::max();
197212

198-
bool m_in_stage2 = false;
199213
prepared_lua_function_t m_process_node;
200214
prepared_lua_function_t m_process_way;
201215
prepared_lua_function_t m_process_relation;
216+
217+
calling_context m_calling_context = calling_context::main;
218+
219+
bool m_in_stage2 = false;
202220
};
203221

204222
#endif // OSM2PGSQL_OUTPUT_FLEX_HPP

0 commit comments

Comments
 (0)