Skip to content

Commit 33823e8

Browse files
authored
feat(python): Support @template.function("name") definitions (#7183)
1 parent 69e9437 commit 33823e8

File tree

3 files changed

+67
-9
lines changed

3 files changed

+67
-9
lines changed

packages/cubejs-backend-native/python/cube/src/__init__.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ class AttrRef:
140140
config: Configuration
141141
attribute: str
142142

143-
def __init__(self, config: Configuration, attribue: str):
143+
def __init__(self, config: Configuration, attribute: str):
144144
self.config = config
145-
self.attribute = attribue
145+
self.attribute = attribute
146146

147147
def __call__(self, func):
148148
if not callable(func):
@@ -163,21 +163,65 @@ class TemplateException(Exception):
163163
pass
164164

165165
class TemplateContext:
166-
def function(self, func):
166+
functions: dict[str, Callable]
167+
168+
def __init__(self):
169+
self.functions = {}
170+
171+
def add_function(self, name, func):
167172
if not callable(func):
168173
raise TemplateException("function registration must be used with functions, actual: '%s'" % type(func).__name__)
169-
170-
return context_func(func)
171174

172-
def filter(self, func):
175+
self.functions[name] = func
176+
177+
def add_filter(self, name, func):
173178
if not callable(func):
174179
raise TemplateException("function registration must be used with functions, actual: '%s'" % type(func).__name__)
175180

176181
raise TemplateException("filter registration is not supported")
177182

183+
def function(self, func):
184+
if isinstance(func, str):
185+
return TemplateFunctionRef(self, func)
186+
187+
self.add_function(func.__name__, func)
188+
return func
189+
190+
def filter(self, func):
191+
if isinstance(func, str):
192+
return TemplateFilterRef(self, func)
193+
194+
self.add_filter(func.__name__, func)
195+
return func
196+
178197
def variable(self, func):
179198
raise TemplateException("variable registration is not supported")
180199

200+
class TemplateFunctionRef:
201+
context: TemplateContext
202+
attribute: str
203+
204+
def __init__(self, context: TemplateContext, attribute: str):
205+
self.context = context
206+
self.attribute = attribute
207+
208+
def __call__(self, func):
209+
self.context.add_function(self.attribute, func)
210+
return func
211+
212+
213+
class TemplateFilterRef:
214+
context: TemplateContext
215+
attribute: str
216+
217+
def __init__(self, context: TemplateContext, attribute: str):
218+
self.context = context
219+
self.attribute = attribute
220+
221+
def __call__(self, func):
222+
self.context.add_filter(self.attribute, func)
223+
return func
224+
181225
def context_func(func):
182226
func.cube_context_func = True
183227
return func

packages/cubejs-backend-native/src/python/entry.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,21 @@ fn python_load_model(mut cx: FunctionContext) -> JsResult<JsPromise> {
7373
let model_module = PyModule::from_code(py, &model_content, &model_file_name, "")?;
7474
let mut collected_functions = CLReprObject::new();
7575

76-
if model_module.hasattr("__execution_context_locals")? {
76+
if model_module.hasattr("template")? {
77+
let functions = model_module
78+
.getattr("template")?
79+
.getattr("functions")?
80+
.downcast::<PyDict>()?;
81+
82+
for (local_key, local_value) in functions.iter() {
83+
if local_value.is_instance_of::<PyFunction>() {
84+
let fun: Py<PyFunction> = local_value.downcast::<PyFunction>()?.into();
85+
collected_functions
86+
.insert(local_key.to_string(), CLRepr::PyExternalFunction(fun));
87+
}
88+
}
89+
// TODO remove all other ways of defining functions
90+
} else if model_module.hasattr("__execution_context_locals")? {
7791
let execution_context_locals = model_module
7892
.getattr("__execution_context_locals")?
7993
.downcast::<PyDict>()?;

packages/cubejs-backend-native/test/templates/jinja-instance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
def arg_sum_integers(a, b):
77
return a + b
88

9-
@template.function
10-
def arg_bool(a):
9+
@template.function("arg_bool")
10+
def ab(a):
1111
return a + 0
1212

1313
@template.function

0 commit comments

Comments
 (0)