Skip to content

Commit fa58a65

Browse files
authored
Merge pull request #1214 from joto/prep-for-changes-multistage-proc
Prep for changes multistage proc
2 parents 0a6f795 + 5787cfb commit fa58a65

File tree

4 files changed

+99
-83
lines changed

4 files changed

+99
-83
lines changed

docs/flex.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ The following functions are defined:
3737
more control than the more convenient other functions.
3838
* `osm2pgsql.mark_way(id)`: Mark the OSM way with the specified id. This way
3939
will be processed (again) in stage 2.
40-
* `osm2pgsql.mark_relation(id)`: Mark the OSM relation with the specified id.
41-
This relation will be processed (again) in stage 2.
4240

4341
You are expected to define one or more of the following functions:
4442

@@ -226,10 +224,9 @@ a default transformation. These are the defaults:
226224

227225
## Stages
228226

229-
Osm2pgsql processes the data in up to two stages. You can mark ways or
230-
relations in stage 1 for processing in stage 2 by calling
231-
`osm2pgsql.mark_way(id)` or `osm2pgsql.mark_relation(id)`, respectively. If you
232-
don't mark any objects, nothing will be done in stage 2.
227+
Osm2pgsql processes the data in up to two stages. You can mark ways in stage 1
228+
for processing in stage 2 by calling `osm2pgsql.mark_way(id)`. If you don't
229+
mark any ways, nothing will be done in stage 2.
233230

234231
You can look at `osm2pgsql.stage` to see in which stage you are.
235232

src/init.lua

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ function osm2pgsql.mark_way(id)
3333
return osm2pgsql.mark('w', id)
3434
end
3535

36-
function osm2pgsql.mark_relation(id)
37-
return osm2pgsql.mark('r', id)
38-
end
39-
4036
function osm2pgsql.clamp(value, low, high)
4137
return math.min(math.max(value, low), high)
4238
end

src/output-flex.cpp

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,18 @@ static char const osm2pgsql_table_name[] = "osm2pgsql.table";
6969
static char const osm2pgsql_object_metatable[] = "osm2pgsql.object_metatable";
7070

7171
prepared_lua_function_t::prepared_lua_function_t(lua_State *lua_state,
72-
char const *name)
72+
calling_context context,
73+
char const *name, int nresults)
7374
{
7475
int const index = lua_gettop(lua_state);
7576

7677
lua_getfield(lua_state, 1, name);
7778

7879
if (lua_type(lua_state, -1) == LUA_TFUNCTION) {
7980
m_index = index;
81+
m_name = name;
82+
m_nresults = nresults;
83+
m_calling_context = context;
8084
return;
8185
}
8286

@@ -518,9 +522,7 @@ int output_flex_t::app_mark()
518522
osmium::object_id_type const id = luaL_checkinteger(lua_state(), 2);
519523

520524
if (type_name[0] == 'w') {
521-
m_stage2_ways_tracker->mark(id);
522-
} else if (type_name[0] == 'r') {
523-
m_stage2_rels_tracker->mark(id);
525+
m_stage2_way_ids->set(id);
524526
}
525527

526528
return 0;
@@ -539,6 +541,12 @@ std::size_t output_flex_t::get_way_nodes()
539541

540542
int output_flex_t::app_get_bbox()
541543
{
544+
if (m_calling_context != calling_context::process_node &&
545+
m_calling_context != calling_context::process_way) {
546+
throw std::runtime_error{"The function get_bbox() can only be called"
547+
" from process_node() or process_way()"};
548+
}
549+
542550
if (lua_gettop(lua_state()) > 1) {
543551
throw std::runtime_error{"No parameter(s) needed for get_box()"};
544552
}
@@ -725,9 +733,9 @@ void output_flex_t::setup_flex_table_columns(flex_table_t *table)
725733

726734
int output_flex_t::app_define_table()
727735
{
728-
if (m_context_node || m_context_way || m_context_relation) {
729-
throw std::runtime_error{"Tables have to be defined before calling any "
730-
"of the process callbacks"};
736+
if (m_calling_context != calling_context::main) {
737+
throw std::runtime_error{"Database tables have to be defined in the"
738+
" main Lua code, not in any of the callbacks"};
731739
}
732740

733741
luaL_checktype(lua_state(), 1, LUA_TTABLE);
@@ -790,8 +798,17 @@ int output_flex_t::table_tostring()
790798

791799
int output_flex_t::table_add_row()
792800
{
801+
if (m_calling_context != calling_context::process_node &&
802+
m_calling_context != calling_context::process_way &&
803+
m_calling_context != calling_context::process_relation) {
804+
throw std::runtime_error{
805+
"The function add_row() can only be called from the "
806+
"process_node/way/relation() functions"};
807+
}
808+
793809
if (lua_gettop(lua_state()) != 2) {
794-
throw std::runtime_error{"Need two parameters: The osm2pgsql.table and the row data"};
810+
throw std::runtime_error{
811+
"Need two parameters: The osm2pgsql.table and the row data"};
795812
}
796813

797814
auto &table_connection =
@@ -821,14 +838,7 @@ int output_flex_t::table_add_row()
821838
throw std::runtime_error{
822839
"Trying to add relation to table '{}'"_format(table.name())};
823840
}
824-
if (m_in_stage2) {
825-
delete_from_table(&table_connection, osmium::item_type::relation,
826-
m_context_relation->id());
827-
}
828841
add_row(&table_connection, *m_context_relation);
829-
} else {
830-
throw std::runtime_error{"The add_row() function can only be called "
831-
"from inside a process function"};
832842
}
833843

834844
return 0;
@@ -1012,16 +1022,21 @@ void output_flex_t::call_lua_function(prepared_lua_function_t func,
10121022
{
10131023
std::lock_guard<std::mutex> guard{lua_mutex};
10141024

1025+
m_calling_context = func.context();
1026+
10151027
lua_pushvalue(lua_state(), func.index()); // the function to call
10161028
push_osm_object_to_lua_stack(
10171029
lua_state(), object,
10181030
get_options()->extra_attributes); // the single argument
10191031

10201032
luaX_set_context(lua_state(), this);
1021-
if (luaX_pcall(lua_state(), 1, 0)) {
1022-
throw std::runtime_error{"Failed to execute lua processing function:"
1023-
" {}"_format(lua_tostring(lua_state(), -1))};
1033+
if (luaX_pcall(lua_state(), 1, func.nresults())) {
1034+
throw std::runtime_error{
1035+
"Failed to execute Lua function 'osm2pgsql.{}':"
1036+
" {}"_format(func.name(), lua_tostring(lua_state(), -1))};
10241037
}
1038+
1039+
m_calling_context = calling_context::main;
10251040
}
10261041

10271042
void output_flex_t::pending_way(osmid_t id)
@@ -1213,8 +1228,7 @@ output_flex_t::clone(std::shared_ptr<middle_query_t> const &mid,
12131228
{
12141229
return std::make_shared<output_flex_t>(
12151230
mid, *get_options(), copy_thread, true, m_lua_state, m_process_node,
1216-
m_process_way, m_process_relation, m_tables,
1217-
m_stage2_ways_tracker, m_stage2_rels_tracker);
1231+
m_process_way, m_process_relation, m_tables, m_stage2_way_ids);
12181232
}
12191233

12201234
output_flex_t::output_flex_t(
@@ -1224,11 +1238,9 @@ output_flex_t::output_flex_t(
12241238
prepared_lua_function_t process_way,
12251239
prepared_lua_function_t process_relation,
12261240
std::shared_ptr<std::vector<flex_table_t>> tables,
1227-
std::shared_ptr<id_tracker> ways_tracker,
1228-
std::shared_ptr<id_tracker> rels_tracker)
1241+
std::shared_ptr<idset_t> stage2_way_ids)
12291242
: output_t(mid, o), m_tables(std::move(tables)),
1230-
m_stage2_ways_tracker(std::move(ways_tracker)),
1231-
m_stage2_rels_tracker(std::move(rels_tracker)), m_copy_thread(copy_thread),
1243+
m_stage2_way_ids(std::move(stage2_way_ids)), m_copy_thread(copy_thread),
12321244
m_lua_state(std::move(lua_state)), m_builder(o.projection),
12331245
m_expire(o.expire_tiles_zoom, o.expire_tiles_max_bbox, o.projection),
12341246
m_buffer(32768, osmium::memory::Buffer::auto_grow::yes),
@@ -1323,21 +1335,20 @@ void output_flex_t::init_lua(std::string const &filename)
13231335
// Check whether the process_* functions are available and store them on
13241336
// the Lua stack for fast access later
13251337
lua_getglobal(lua_state(), "osm2pgsql");
1326-
m_process_node = prepared_lua_function_t{lua_state(), "process_node"};
1327-
m_process_way = prepared_lua_function_t{lua_state(), "process_way"};
1328-
m_process_relation =
1329-
prepared_lua_function_t{lua_state(), "process_relation"};
1338+
m_process_node = prepared_lua_function_t{
1339+
lua_state(), calling_context::process_node, "process_node"};
1340+
m_process_way = prepared_lua_function_t{
1341+
lua_state(), calling_context::process_way, "process_way"};
1342+
m_process_relation = prepared_lua_function_t{
1343+
lua_state(), calling_context::process_relation, "process_relation"};
13301344

13311345
lua_remove(lua_state(), 1); // global "osm2pgsql"
13321346
}
13331347

13341348
void output_flex_t::stage2_proc()
13351349
{
1336-
bool const has_marked_ways = !m_stage2_ways_tracker->empty();
1337-
bool const has_marked_rels = !m_stage2_rels_tracker->empty();
1338-
1339-
if (!has_marked_ways && !has_marked_rels) {
1340-
fmt::print(stderr, "Skipping stage 2 (no marked objects).\n");
1350+
if (m_stage2_way_ids->empty()) {
1351+
fmt::print(stderr, "Skipping stage 2 (no marked ways).\n");
13411352
return;
13421353
}
13431354

@@ -1349,10 +1360,7 @@ void output_flex_t::stage2_proc()
13491360
util::timer_t timer;
13501361

13511362
for (auto &table : m_table_connections) {
1352-
if ((has_marked_ways &&
1353-
table.table().matches_type(osmium::item_type::way)) ||
1354-
(has_marked_rels &&
1355-
table.table().matches_type(osmium::item_type::relation))) {
1363+
if (table.table().matches_type(osmium::item_type::way)) {
13561364
fmt::print(stderr, " Creating id index on table '{}'...\n",
13571365
table.table().name());
13581366
table.create_id_index();
@@ -1372,32 +1380,19 @@ void output_flex_t::stage2_proc()
13721380
lua_setfield(lua_state(), -2, "stage");
13731381
lua_pop(lua_state(), 1); // osm2pgsql
13741382

1375-
osmid_t id;
1376-
13771383
fmt::print(stderr, "Entering stage 2 processing of {} ways...\n"_format(
1378-
m_stage2_ways_tracker->size()));
1384+
m_stage2_way_ids->size()));
13791385

1380-
while (id_tracker::is_valid((id = m_stage2_ways_tracker->pop_mark()))) {
1386+
m_stage2_way_ids->sort_unique();
1387+
for (osmid_t const id : *m_stage2_way_ids) {
13811388
m_buffer.clear();
13821389
if (!m_mid->way_get(id, m_buffer)) {
13831390
continue;
13841391
}
13851392
auto &way = m_buffer.get<osmium::Way>(0);
13861393
way_add(&way);
13871394
}
1388-
1389-
fmt::print(stderr,
1390-
"Entering stage 2 processing of {} relations...\n"_format(
1391-
m_stage2_rels_tracker->size()));
1392-
1393-
while (id_tracker::is_valid((id = m_stage2_rels_tracker->pop_mark()))) {
1394-
m_rels_buffer.clear();
1395-
if (!m_mid->relation_get(id, m_rels_buffer)) {
1396-
continue;
1397-
}
1398-
auto const &relation = m_rels_buffer.get<osmium::Relation>(0);
1399-
relation_add(relation);
1400-
}
1395+
m_stage2_way_ids->clear();
14011396
}
14021397

14031398
void output_flex_t::merge_expire_trees(output_t *other)

src/output-flex.hpp

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
#include "flex-table.hpp"
88
#include "format.hpp"
99
#include "geom-transform.hpp"
10-
#include "id-tracker.hpp"
1110
#include "osmium-builder.hpp"
1211
#include "output.hpp"
1312
#include "table.hpp"
1413
#include "tagtransform.hpp"
1514

15+
#include <osmium/index/id_set.hpp>
1616
#include <osmium/osm/item_type.hpp>
1717

1818
extern "C"
@@ -26,6 +26,20 @@ extern "C"
2626
#include <utility>
2727
#include <vector>
2828

29+
using idset_t = osmium::index::IdSetSmall<osmid_t>;
30+
31+
/**
32+
* When C++ code is called from the Lua code we sometimes need to know
33+
* in what context this happens. These are the possible contexts.
34+
*/
35+
enum class calling_context
36+
{
37+
main = 0, ///< In main context, i.e. the Lua script outside any callbacks
38+
process_node = 1, ///< In the process_node() callback
39+
process_way = 2, ///< In the process_way() callback
40+
process_relation = 3 ///< In the process_relation() callback
41+
};
42+
2943
/**
3044
* The flex output calls several user-defined Lua functions. They are
3145
* "prepared" by putting the function pointers on the Lua stack. Objects
@@ -41,37 +55,49 @@ class prepared_lua_function_t
4155
/**
4256
* Get function with the name "osm2pgsql.name" from Lua and put pointer
4357
* to it on the Lua stack.
58+
*
59+
* \param lua_state Current Lua state.
60+
* \param name Name of the function.
61+
* \param nresults The number of results this function is supposed to have.
4462
*/
45-
prepared_lua_function_t(lua_State *lua_state, const char *name);
63+
prepared_lua_function_t(lua_State *lua_state, calling_context context,
64+
const char *name, int nresults = 0);
4665

4766
/// Return the index of the function on the Lua stack.
4867
int index() const noexcept { return m_index; }
4968

69+
/// The name of the function.
70+
char const* name() const noexcept { return m_name; }
71+
72+
/// The number of results this function is expected to have.
73+
int nresults() const noexcept { return m_nresults; }
74+
75+
calling_context context() const noexcept { return m_calling_context; }
76+
5077
/// Is this function defined in the users Lua code?
5178
operator bool() const noexcept { return m_index != 0; }
5279

5380
private:
81+
char const *m_name = nullptr;
5482
int m_index = 0;
83+
int m_nresults = 0;
84+
calling_context m_calling_context = calling_context::main;
5585
};
5686

5787
class output_flex_t : public output_t
5888
{
5989

6090
public:
61-
output_flex_t(std::shared_ptr<middle_query_t> const &mid,
62-
options_t const &options,
63-
std::shared_ptr<db_copy_thread_t> const &copy_thread,
64-
bool is_clone = false,
65-
std::shared_ptr<lua_State> lua_state = nullptr,
66-
prepared_lua_function_t process_node = {},
67-
prepared_lua_function_t process_way = {},
68-
prepared_lua_function_t process_relation = {},
69-
std::shared_ptr<std::vector<flex_table_t>> tables =
70-
std::make_shared<std::vector<flex_table_t>>(),
71-
std::shared_ptr<id_tracker> ways_tracker =
72-
std::make_shared<id_tracker>(),
73-
std::shared_ptr<id_tracker> rels_tracker =
74-
std::make_shared<id_tracker>());
91+
output_flex_t(
92+
std::shared_ptr<middle_query_t> const &mid, options_t const &options,
93+
std::shared_ptr<db_copy_thread_t> const &copy_thread,
94+
bool is_clone = false, std::shared_ptr<lua_State> lua_state = nullptr,
95+
prepared_lua_function_t process_node = {},
96+
prepared_lua_function_t process_way = {},
97+
prepared_lua_function_t process_relation = {},
98+
std::shared_ptr<std::vector<flex_table_t>> tables =
99+
std::make_shared<std::vector<flex_table_t>>(),
100+
std::shared_ptr<idset_t> stage2_way_ids = std::make_shared<idset_t>());
75101

76102
output_flex_t(output_flex_t const &) = delete;
77103
output_flex_t &operator=(output_flex_t const &) = delete;
@@ -166,8 +192,7 @@ class output_flex_t : public output_t
166192
std::shared_ptr<std::vector<flex_table_t>> m_tables;
167193
std::vector<table_connection_t> m_table_connections;
168194

169-
std::shared_ptr<id_tracker> m_stage2_ways_tracker;
170-
std::shared_ptr<id_tracker> m_stage2_rels_tracker;
195+
std::shared_ptr<idset_t> m_stage2_way_ids;
171196

172197
std::shared_ptr<db_copy_thread_t> m_copy_thread;
173198

@@ -185,10 +210,13 @@ class output_flex_t : public output_t
185210

186211
std::size_t m_num_way_nodes = std::numeric_limits<std::size_t>::max();
187212

188-
bool m_in_stage2 = false;
189213
prepared_lua_function_t m_process_node;
190214
prepared_lua_function_t m_process_way;
191215
prepared_lua_function_t m_process_relation;
216+
217+
calling_context m_calling_context = calling_context::main;
218+
219+
bool m_in_stage2 = false;
192220
};
193221

194222
#endif // OSM2PGSQL_OUTPUT_FLEX_HPP

0 commit comments

Comments
 (0)