Skip to content

Commit 535c123

Browse files
committed
testing with sharrow
1 parent 11af951 commit 535c123

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

test/test_mtc.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,14 @@ def regress(ext, out_dir):
227227

228228

229229
@pytest.mark.parametrize(
230-
"chunk_training_mode,recode_pipeline_columns",
230+
"chunk_training_mode,recode_pipeline_columns,sharrow_enabled",
231231
[
232232
("disabled", True),
233233
("explicit", False),
234+
("explicit", True, True),
234235
],
235236
)
236-
def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
237+
def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns, sharrow_enabled):
237238
import activitysim.abm # register components # noqa: F401
238239

239240
out_dir = _test_path(f"output-progressive-recode{recode_pipeline_columns}")
@@ -267,8 +268,16 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
267268
],
268269
},
269270
"recode_pipeline_columns": recode_pipeline_columns,
271+
"trace_hh_id": 1196298,
270272
}
271273

274+
if sharrow_enabled and not recode_pipeline_columns:
275+
raise ValueError("sharrow_enabled requires recode_pipeline_columns")
276+
277+
if sharrow_enabled:
278+
settings["sharrow"] = "test"
279+
del settings["trace_hh_id"]
280+
272281
state = workflow.State.make_default(
273282
working_dir=working_dir,
274283
configs_dir=("ext-configs", "configs"),
@@ -279,11 +288,11 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
279288
)
280289
state.filesystem.persist_sharrow_cache()
281290
state.logging.config_logger()
282-
state.settings.trace_hh_id = 1196298
283291

284292
assert state.settings.models == EXPECTED_MODELS
285293
assert state.settings.chunk_size == 0
286-
assert not state.settings.sharrow
294+
if not sharrow_enabled:
295+
assert not state.settings.sharrow
287296

288297
ref_pipeline = Path(__file__).parent.joinpath(
289298
f"reference-pipeline-extended-recode{recode_pipeline_columns}.zip"
@@ -319,3 +328,7 @@ def test_mtc_extended_progressive(chunk_training_mode, recode_pipeline_columns):
319328

320329

321330
if __name__ == "__main__":
331+
test_mtc_extended_progressive("disabled", True, False)
332+
test_mtc_extended_progressive("explicit", False, False)
333+
test_mtc_extended_progressive("explicit", True, True)
334+

0 commit comments

Comments
 (0)