-
Notifications
You must be signed in to change notification settings - Fork 145
Inplace Composite and ScalarLoop Ops with multiple outputs #1322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
ae1177e
a12acfd
789b509
5fc2cb8
d03dc8a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ def __init__( | |
constant: Sequence[Variable] | None = None, | ||
until: Variable | None = None, | ||
name="ScalarLoop", | ||
**kwargs, | ||
): | ||
if constant is None: | ||
constant = [] | ||
|
@@ -75,7 +76,7 @@ def __init__( | |
self.nout = len(self.outputs) | ||
self.name = name | ||
|
||
super().__init__() | ||
super().__init__(**kwargs) | ||
|
||
def output_types(self, input_types): | ||
return self.outputs_type | ||
|
@@ -115,7 +116,7 @@ def fgraph(self): | |
self._fgraph = fgraph | ||
return self._fgraph | ||
|
||
def clone(self): | ||
def clone(self, name=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unrelated, but I just checked and the signature of [inline code reference and remainder of this comment lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I don't think it's standardized, unlike `node.clone`. |
||
if self.is_while: | ||
*update, until = self.outputs | ||
else: | ||
|
@@ -127,28 +128,16 @@ def clone(self): | |
update=update, | ||
constant=constant, | ||
until=until, | ||
name=self.name, | ||
name=self.name if name is None else name, | ||
**kwargs, | ||
) | ||
|
||
@property | ||
def fn(self): | ||
raise NotImplementedError | ||
|
||
def make_new_inplace(self, output_types_preference=None, name=None): | ||
""" | ||
This op.__init__ fct don't have the same parameter as other scalar op. | ||
This break the insert_inplace_optimizer optimization. | ||
This fct allow fix patch this. | ||
|
||
""" | ||
d = {k: getattr(self, k) for k in self.init_param} | ||
out = self.__class__(**d) | ||
if name: | ||
out.name = name | ||
else: | ||
name = out.name | ||
super(ScalarLoop, out).__init__(output_types_preference, name) | ||
return out | ||
return self.clone(output_types_preference=output_types_preference, name=name) | ||
|
||
def make_node(self, n_steps, *inputs): | ||
assert len(inputs) == self.nin - 1 | ||
|
@@ -229,11 +218,11 @@ def c_code_template(self): | |
c: f"%(i{int(i)})s" | ||
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1) | ||
} | ||
update_subd = { | ||
out_subd = { | ||
u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update]) | ||
} | ||
until_subd = {u: "until" for u in fgraph.outputs[n_update:]} | ||
subd = {**carry_subd, **constant_subd, **update_subd, **until_subd} | ||
subd = {**carry_subd, **constant_subd, **until_subd} | ||
|
||
for var in fgraph.variables: | ||
if var.owner is None: | ||
|
@@ -257,11 +246,11 @@ def c_code_template(self): | |
_c_code += "bool until = 1;\n\n" | ||
|
||
# Copy carried inputs | ||
for i, (var, name) in enumerate(carry_subd.items()): | ||
copy_var_name = f"{name}_copy{i}" | ||
_c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n" | ||
carry_subd[var] = copy_var_name | ||
subd[var] = copy_var_name | ||
for i, (var, name) in enumerate(carry_subd.items(), start=1): | ||
carry_var_name = f"{name}_carry{i}" | ||
_c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n" | ||
carry_subd[var] = carry_var_name | ||
subd[var] = carry_var_name | ||
|
||
# _c_code += 'printf("inputs=[");' | ||
# for i in range(1, len(fgraph.inputs)): | ||
|
@@ -270,9 +259,8 @@ def c_code_template(self): | |
|
||
_c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n" | ||
|
||
self.nodenames = [ | ||
f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort()) | ||
] | ||
# Used by self.c_support_code_apply | ||
self.nodenames = nodenames = [] | ||
|
||
i = 0 | ||
for j, node in enumerate(fgraph.toposort()): | ||
|
@@ -282,9 +270,13 @@ def c_code_template(self): | |
name = f"V%(id)s_tmp{int(i)}" | ||
subd[output] = name | ||
_c_code += f"{output.type.dtype_specs()[1]} {name};\n" | ||
|
||
nodename = f"%(nodename)s_subnode{int(j)}" | ||
nodenames.append(nodename) | ||
|
||
s = node.op.c_code( | ||
node, | ||
self.nodenames[j], | ||
nodename, | ||
# Any node that depended on `init` will depend on `update` instead | ||
# The initial value of `update` was set to `init` before the loop | ||
[subd[input] for input in node.inputs], | ||
|
@@ -294,10 +286,12 @@ def c_code_template(self): | |
_c_code += s | ||
_c_code += "\n" | ||
|
||
# Set the carry variables to the output variables | ||
# Update the carry variables to the output variables | ||
_c_code += "\n" | ||
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True): | ||
_c_code += f"{init} = {update};\n" | ||
for carry, out in zip( | ||
carry_subd.values(), fgraph.outputs[:n_update], strict=True | ||
): | ||
_c_code += f"{carry} = {subd[out]};\n" | ||
|
||
# _c_code += 'printf("%%ld\\n", i);\n' | ||
# for carry in range(1, 10): | ||
|
@@ -309,6 +303,10 @@ def c_code_template(self): | |
# End of the loop | ||
_c_code += "}\n" | ||
|
||
# Assign the carry variables to the outputs | ||
for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True): | ||
_c_code += f"{out} = {carry};\n" | ||
|
||
# Output until flag | ||
if self.is_while: | ||
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n" | ||
|
@@ -343,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub): | |
return res | ||
|
||
def c_code_cache_version_outer(self): | ||
return (3,) | ||
return (4,) |
Uh oh!
There was an error while loading. Please reload this page.