Skip to content

Commit 8aeee64

Browse files
willb, karlhigley, and rnyak
authored
Add support for serializing modules involved in LambdaOp execution by value (#1741)
* Allow users to specify module serialization hints This commit adds an optional parameter to Workflow.save so that users can indicate that certain modules should be serialized by value. This is necessary if a LambdaOp in a workflow depends on a module whose source file will not be available in the deployment environment. Related to #1737. * Adds automatic inference of LambdaOp module dependencies This commit adds code to automatically infer LambdaOp module dependencies in several common cases: 1. in which a function is passed to LambdaOp by name, 2. in which the argument to LambdaOp is a lambda expression that refers to a function by a fully-qualified name, and 3. in which the argument to LambdaOp is a lambda expression that refers to a function via another variable The current implementation does not inspect the bodies of any function passed to LambdaOp, and many corner cases are (necessarily) omitted. However, this support should be complete enough to be useful for many users. Automatic inference is optional (via a parameter to Workflow.save) but it could be the default in the future. Related to issue #1737. * Added tests related to issue #1737 * Fix linter errors * Workflow.save: reset cloudpickle behavior changes on return * aligned formatting with black's expectations --------- Co-authored-by: Karl Higley <[email protected]> Co-authored-by: rnyak <[email protected]>
1 parent 7e1b198 commit 8aeee64

File tree

2 files changed

+175
-5
lines changed

2 files changed

+175
-5
lines changed

nvtabular/workflow/workflow.py

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
17+
import inspect
1618
import json
1719
import logging
1820
import sys
1921
import time
22+
import types
2023
import warnings
2124
from typing import TYPE_CHECKING, Optional
2225

@@ -31,9 +34,10 @@
3134

3235
from merlin.dag import Graph
3336
from merlin.dag.executors import DaskExecutor, LocalExecutor
37+
from merlin.dag.node import iter_nodes
3438
from merlin.io import Dataset
3539
from merlin.schema import Schema
36-
from nvtabular.ops import StatOperator
40+
from nvtabular.ops import LambdaOp, StatOperator
3741
from nvtabular.workflow.node import WorkflowNode
3842

3943
LOG = logging.getLogger("nvtabular")
@@ -255,13 +259,68 @@ def _transform_df(self, df):
255259

256260
return LocalExecutor().transform(df, self.output_node, self.output_dtypes)
257261

258-
def save(self, path):
262+
@classmethod
263+
def _getmodules(cls, fs):
264+
"""
265+
Returns an imprecise but useful approximation of the list of modules
266+
necessary to execute a given list of functions. This approximation is
267+
sound (all modules listed are required by the supplied functions) but not
268+
necessarily complete (not all modules required will necessarily be returned).
269+
270+
For function literals (lambda expressions), this returns
271+
1. the names of every module referenced in the lambda expression, e.g.,
272+
`m` for `lambda x: m.f(x)` and
273+
2. the names of the declaring module for every function referenced in
274+
the lambda expression, e.g. `m` for `import m.f; lambda x: f(x)`
275+
276+
For declared functions, this returns the names of their declaring modules.
277+
278+
The return value will exclude all built-in modules and (on Python 3.10 or later)
279+
all standard library modules.
280+
"""
281+
result = set()
282+
283+
exclusions = set(sys.builtin_module_names)
284+
if hasattr(sys, "stdlib_module_names"):
285+
# sys.stdlib_module_names is only available in Python 3.10 and beyond
286+
exclusions = exclusions | sys.stdlib_module_names
287+
288+
for f in fs:
289+
if f.__name__ == "<lambda>":
290+
for closurevars in [
291+
inspect.getclosurevars(f).globals,
292+
inspect.getclosurevars(f).nonlocals,
293+
]:
294+
for name, val in closurevars.items():
295+
print(f"{name} = {val}")
296+
if isinstance(val, types.ModuleType):
297+
result.add(val)
298+
elif isinstance(val, types.FunctionType):
299+
mod = inspect.getmodule(val)
300+
if mod is not None:
301+
result.add(mod)
302+
else:
303+
mod = inspect.getmodule(f)
304+
if mod is not None:
305+
result.add(mod)
306+
307+
return [mod for mod in result if mod.__name__ not in exclusions]
308+
309+
def save(self, path, modules_byvalue=None):
259310
"""Save this workflow to disk
260311
261312
Parameters
262313
----------
263314
path: str
264315
The path to save the workflow to
316+
modules_byvalue:
317+
A list of modules that should be serialized by value. This
318+
should include any modules that will not be available on
319+
the host where this workflow is ultimately deserialized.
320+
321+
In lieu of an explicit list, pass None to serialize all modules
322+
by reference or pass "auto" to use a heuristic to infer which
323+
modules to serialize by value.
265324
"""
266325
# avoid a circular import getting the version
267326
from nvtabular import __version__ as nvt_version
@@ -290,9 +349,51 @@ def save(self, path):
290349
o,
291350
)
292351

293-
# dump out the full workflow (graph/stats/operators etc) using cloudpickle
294-
with fs.open(fs.sep.join([path, "workflow.pkl"]), "wb") as o:
295-
cloudpickle.dump(self, o)
352+
# track existing by-value modules
353+
preexisting_modules_byvalue = set(cloudpickle.list_registry_pickle_by_value())
354+
355+
# direct cloudpickle to serialize selected modules by value
356+
if modules_byvalue is None:
357+
modules_byvalue = []
358+
elif modules_byvalue == "auto":
359+
l_nodes = self.graph.get_nodes_by_op_type(
360+
list(iter_nodes([self.graph.output_node])), LambdaOp
361+
)
362+
363+
try:
364+
modules_byvalue = Workflow._getmodules([ln.op.f for ln in l_nodes])
365+
except RuntimeError as ex:
366+
warnings.warn(
367+
"Failed to automatically infer modules to serialize by value. "
368+
f'Reason given was "{str(ex)}"'
369+
)
370+
371+
try:
372+
for m in modules_byvalue:
373+
if isinstance(m, types.ModuleType):
374+
cloudpickle.register_pickle_by_value(m)
375+
elif isinstance(m, str) and m in sys.modules:
376+
cloudpickle.register_pickle_by_value(sys.modules[m])
377+
except RuntimeError as ex:
378+
warnings.warn(
379+
f'Failed to register modules to serialize by value. Reason given was "{str(ex)}"'
380+
)
381+
382+
try:
383+
# dump out the full workflow (graph/stats/operators etc) using cloudpickle
384+
with fs.open(fs.sep.join([path, "workflow.pkl"]), "wb") as o:
385+
cloudpickle.dump(self, o)
386+
finally:
387+
# return all modules that we set to serialize by value to by-reference
388+
# (i.e., retain modules that were set to serialize by value before this invocation)
389+
390+
for m in modules_byvalue:
391+
if isinstance(m, types.ModuleType):
392+
if m.__name__ not in preexisting_modules_byvalue:
393+
cloudpickle.unregister_pickle_by_value(m)
394+
elif isinstance(m, str) and m in sys.modules:
395+
if m not in preexisting_modules_byvalue:
396+
cloudpickle.unregister_pickle_by_value(sys.modules[m])
296397

297398
@classmethod
298399
def load(cls, path, client=None) -> "Workflow":

tests/unit/workflow/test_workflow.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import math
1919
import os
2020
import shutil
21+
import sys
2122

2223
try:
2324
import cudf
@@ -666,3 +667,71 @@ def test_workflow_saved_schema(tmpdir):
666667
for node in postorder_iter_nodes(workflow2.output_node):
667668
assert node.input_schema is not None
668669
assert node.output_schema is not None
670+
671+
672+
def test_workflow_infer_modules_byvalue(tmp_path):
    """Workflow._getmodules should find a LambdaOp function's module whether the
    function is referenced by name, via a fully-qualified lambda, or via a
    closure variable."""
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    try:
        with open(module_fn, "w") as module_f:
            module_f.write("def identity(col):\n return col")

        import not_a_real_module

        f_0 = not_a_real_module.identity
        f_1 = lambda x: not_a_real_module.identity(x)  # noqa
        f_2 = lambda x: f_0(x)  # noqa

        for fn, f in {
            "not_a_real_module.identity": f_0,
            "lambda x: not_a_real_module.identity(x)": f_1,
            "lambda x: f_0(x)": f_2,
        }.items():
            assert not_a_real_module in Workflow._getmodules(
                [f]
            ), f"inferred module dependencies from {fn}"

    finally:
        # Remove exactly the entry we added rather than blindly popping the
        # last element, and tolerate the case where the import never happened.
        sys.path.remove(str(tmp_path))
        sys.modules.pop("not_a_real_module", None)
698+
699+
700+
def test_workflow_explicit_modules_byvalue(tmp_path):
    """A workflow saved with an explicit modules_byvalue list must load even
    after the module's source file is deleted, because the module was
    serialized by value."""
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    try:
        with open(module_fn, "w") as module_f:
            module_f.write("def identity(col):\n return col")

        import not_a_real_module

        wf = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(not_a_real_module.identity))

        wf.save(str(tmp_path / "identity-workflow"), modules_byvalue=[not_a_real_module])

        # Make the module unavailable before loading to prove the pickle does
        # not depend on its source file.
        del not_a_real_module
        sys.modules.pop("not_a_real_module", None)
        os.unlink(str(tmp_path / "not_a_real_module.py"))

        Workflow.load(str(tmp_path / "identity-workflow"))

    finally:
        # Original version leaked tmp_path onto sys.path for later tests.
        sys.path.remove(str(tmp_path))
        sys.modules.pop("not_a_real_module", None)
718+
719+
720+
def test_workflow_auto_infer_modules_byvalue(tmp_path):
    """With modules_byvalue="auto", Workflow.save should infer and serialize a
    LambdaOp's module by value, so the workflow loads after the module's
    source file is deleted."""
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    try:
        with open(module_fn, "w") as module_f:
            module_f.write("def identity(col):\n return col")

        import not_a_real_module

        wf = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(not_a_real_module.identity))

        wf.save(str(tmp_path / "identity-workflow"), modules_byvalue="auto")

        # Make the module unavailable before loading to prove the pickle does
        # not depend on its source file.
        del not_a_real_module
        sys.modules.pop("not_a_real_module", None)
        os.unlink(str(tmp_path / "not_a_real_module.py"))

        Workflow.load(str(tmp_path / "identity-workflow"))

    finally:
        # Original version leaked tmp_path onto sys.path for later tests.
        sys.path.remove(str(tmp_path))
        sys.modules.pop("not_a_real_module", None)

0 commit comments

Comments
 (0)