16
16
17
17
import pytensor
18
18
from pytensor .configdefaults import config
19
- from pytensor .graph .basic import Apply , NoParams , Variable
19
+ from pytensor .graph .basic import Apply , Variable
20
20
from pytensor .graph .utils import (
21
21
MetaObject ,
22
- MethodNotDefined ,
23
22
TestValueError ,
24
23
add_tag_trace ,
25
24
get_variable_trace_string ,
26
25
)
27
- from pytensor .link .c .params_type import Params , ParamsType
28
26
29
27
30
28
if TYPE_CHECKING :
37
35
ComputeMapType = dict [Variable , list [bool ]]
38
36
InputStorageType = list [StorageCellType ]
39
37
OutputStorageType = list [StorageCellType ]
40
- ParamsInputType = Optional [tuple [Any , ...]]
41
- PerformMethodType = Callable [
42
- [Apply , list [Any ], OutputStorageType , ParamsInputType ], None
43
- ]
38
+ PerformMethodType = Callable [[Apply , list [Any ], OutputStorageType ], None ]
44
39
BasicThunkType = Callable [[], None ]
45
40
ThunkCallableType = Callable [
46
41
[PerformMethodType , StorageMapType , ComputeMapType , Apply ], None
@@ -202,7 +197,6 @@ class Op(MetaObject):
202
197
203
198
itypes : Optional [Sequence ["Type" ]] = None
204
199
otypes : Optional [Sequence ["Type" ]] = None
205
- params_type : Optional [ParamsType ] = None
206
200
207
201
_output_type_depends_on_input_value = False
208
202
"""
@@ -426,7 +420,6 @@ def perform(
426
420
node : Apply ,
427
421
inputs : Sequence [Any ],
428
422
output_storage : OutputStorageType ,
429
- params : ParamsInputType = None ,
430
423
) -> None :
431
424
"""Calculate the function on the inputs and put the variables in the output storage.
432
425
@@ -442,8 +435,6 @@ def perform(
442
435
these lists). Each sub-list corresponds to value of each
443
436
`Variable` in :attr:`node.outputs`. The primary purpose of this method
444
437
is to set the values of these sub-lists.
445
- params
446
- A tuple containing the values of each entry in :attr:`Op.__props__`.
447
438
448
439
Notes
449
440
-----
@@ -481,22 +472,6 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
481
472
"""
482
473
return True
483
474
484
- def get_params (self , node : Apply ) -> Params :
485
- """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
486
- if isinstance (self .params_type , ParamsType ):
487
- wrapper = self .params_type
488
- if not all (hasattr (self , field ) for field in wrapper .fields ):
489
- # Let's print missing attributes for debugging.
490
- not_found = tuple (
491
- field for field in wrapper .fields if not hasattr (self , field )
492
- )
493
- raise AttributeError (
494
- f"{ type (self ).__name__ } : missing attributes { not_found } for ParamsType."
495
- )
496
- # ParamsType.get_params() will apply filtering to attributes.
497
- return self .params_type .get_params (self )
498
- raise MethodNotDefined ("get_params" )
499
-
500
475
def prepare_node (
501
476
self ,
502
477
node : Apply ,
@@ -538,34 +513,12 @@ def make_py_thunk(
538
513
else :
539
514
p = node .op .perform
540
515
541
- params = node .run_params ()
542
-
543
- if params is NoParams :
544
- # default arguments are stored in the closure of `rval`
545
- @is_thunk_type
546
- def rval (
547
- p = p , i = node_input_storage , o = node_output_storage , n = node , params = None
548
- ):
549
- r = p (n , [x [0 ] for x in i ], o )
550
- for o in node .outputs :
551
- compute_map [o ][0 ] = True
552
- return r
553
-
554
- else :
555
- params_val = node .params_type .filter (params )
556
-
557
- @is_thunk_type
558
- def rval (
559
- p = p ,
560
- i = node_input_storage ,
561
- o = node_output_storage ,
562
- n = node ,
563
- params = params_val ,
564
- ):
565
- r = p (n , [x [0 ] for x in i ], o , params )
566
- for o in node .outputs :
567
- compute_map [o ][0 ] = True
568
- return r
516
+ @is_thunk_type
517
+ def rval (p = p , i = node_input_storage , o = node_output_storage , n = node ):
518
+ r = p (n , [x [0 ] for x in i ], o )
519
+ for o in node .outputs :
520
+ compute_map [o ][0 ] = True
521
+ return r
569
522
570
523
rval .inputs = node_input_storage
571
524
rval .outputs = node_output_storage
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
640
593
641
594
"""
642
595
643
- def perform (self , node , inputs , output_storage , params = None ):
596
+ def perform (self , node , inputs , output_storage ):
644
597
raise NotImplementedError ("No Python implementation is provided by this Op." )
645
598
646
599
0 commit comments