2929
3030.. code-block:: python
3131
32- params_type = ParamsType(attr1=TensorType('int32', shape=(None, None)), attr2=ScalarType('float64'))
32+ params_type = ParamsType(
33+ attr1=TensorType("int32", shape=(None, None)), attr2=ScalarType("float64")
34+ )
3335
3436If your op contains attributes ``attr1`` **and** ``attr2``, the default ``op.get_params()``
3537implementation will automatically try to look for it and generate an appropriate Params object.
@@ -77,38 +79,48 @@ def __init__(value_attr1, value_attr2):
7779 from pytensor.link.c.params_type import ParamsType
7880 from pytensor.link.c.type import EnumType, EnumList
7981
80- wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2', 'CONSTANT_3'),
81- enum2=EnumType(PI=3.14, EPSILON=0.001))
82+ wrapper = ParamsType(
83+ enum1=EnumList("CONSTANT_1", "CONSTANT_2", "CONSTANT_3"),
84+ enum2=EnumType(PI=3.14, EPSILON=0.001),
85+ )
8286
8387 # Each enum constant is available as a wrapper attribute:
84- print(wrapper.CONSTANT_1, wrapper.CONSTANT_2, wrapper.CONSTANT_3,
85- wrapper.PI, wrapper.EPSILON)
88+ print(
89+ wrapper.CONSTANT_1,
90+ wrapper.CONSTANT_2,
91+ wrapper.CONSTANT_3,
92+ wrapper.PI,
93+ wrapper.EPSILON,
94+ )
8695
8796 # For convenience, you can also look for a constant by name with
8897 # ``ParamsType.get_enum()`` method.
89- pi = wrapper.get_enum('PI' )
90- epsilon = wrapper.get_enum(' EPSILON' )
91- constant_2 = wrapper.get_enum(' CONSTANT_2' )
98+ pi = wrapper.get_enum("PI" )
99+ epsilon = wrapper.get_enum(" EPSILON" )
100+ constant_2 = wrapper.get_enum(" CONSTANT_2" )
92101 print(pi, epsilon, constant_2)
93102
94103This implies that a ParamsType cannot contain different enum types with common enum names::
95104
96105 # Following line will raise an error,
97106 # as there is a "CONSTANT_1" defined both in enum1 and enum2.
98- wrapper = ParamsType(enum1=EnumList('CONSTANT_1', 'CONSTANT_2'),
99- enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5))
107+ wrapper = ParamsType(
108+ enum1=EnumList("CONSTANT_1", "CONSTANT_2"),
109+ enum2=EnumType(CONSTANT_1=0, CONSTANT_3=5),
110+ )
100111
101112If your enum types contain constant aliases, you can retrieve them from ParamsType
102113with ``ParamsType.enum_from_alias(alias)`` method (see :class:`pytensor.link.c.type.EnumType`
103114for more info about enumeration aliases).
104115
105116.. code-block:: python
106117
107- wrapper = ParamsType(enum1=EnumList('A', ('B', 'beta'), 'C'),
108- enum2=EnumList(('D', 'delta'), 'E', 'F'))
118+ wrapper = ParamsType(
119+ enum1=EnumList("A", ("B", "beta"), "C"), enum2=EnumList(("D", "delta"), "E", "F")
120+ )
109121 b1 = wrapper.B
110- b2 = wrapper.get_enum('B' )
111- b3 = wrapper.enum_from_alias(' beta' )
122+ b2 = wrapper.get_enum("B" )
123+ b3 = wrapper.enum_from_alias(" beta" )
112124 assert b1 == b2 == b3
113125
114126"""
@@ -236,10 +248,13 @@ class Params(dict):
236248
237249 from pytensor.link.c.params_type import ParamsType, Params
238250 from pytensor.scalar import ScalarType
251+
239252 # You must create a ParamsType first:
240- params_type = ParamsType(attr1=ScalarType('int32'),
241- key2=ScalarType('float32'),
242- field3=ScalarType('int64'))
253+ params_type = ParamsType(
254+ attr1=ScalarType("int32"),
255+ key2=ScalarType("float32"),
256+ field3=ScalarType("int64"),
257+ )
243258 # Then you can create a Params object with
244259 # the params type defined above and values for attributes.
245260 params = Params(params_type, attr1=1, key2=2.0, field3=3)
@@ -491,11 +506,13 @@ def get_enum(self, key):
491506 from pytensor.link.c.type import EnumType, EnumList
492507 from pytensor.scalar import ScalarType
493508
494- wrapper = ParamsType(scalar=ScalarType('int32'),
495- letters=EnumType(A=1, B=2, C=3),
496- digits=EnumList('ZERO', 'ONE', 'TWO'))
497- print(wrapper.get_enum('C')) # 3
498- print(wrapper.get_enum('TWO')) # 2
509+ wrapper = ParamsType(
510+ scalar=ScalarType("int32"),
511+ letters=EnumType(A=1, B=2, C=3),
512+ digits=EnumList("ZERO", "ONE", "TWO"),
513+ )
514+ print(wrapper.get_enum("C")) # 3
515+ print(wrapper.get_enum("TWO")) # 2
499516
500517 # You can also directly do:
501518 print(wrapper.C)
@@ -520,17 +537,19 @@ def enum_from_alias(self, alias):
520537 from pytensor.link.c.type import EnumType, EnumList
521538 from pytensor.scalar import ScalarType
522539
523- wrapper = ParamsType(scalar=ScalarType('int32'),
524- letters=EnumType(A=(1, 'alpha'), B=(2, 'beta'), C=3),
525- digits=EnumList(('ZERO', 'nothing'), ('ONE', 'unit'), ('TWO', 'couple')))
526- print(wrapper.get_enum('C')) # 3
527- print(wrapper.get_enum('TWO')) # 2
528- print(wrapper.enum_from_alias('alpha')) # 1
529- print(wrapper.enum_from_alias('nothing')) # 0
540+ wrapper = ParamsType(
541+ scalar=ScalarType("int32"),
542+ letters=EnumType(A=(1, "alpha"), B=(2, "beta"), C=3),
543+ digits=EnumList(("ZERO", "nothing"), ("ONE", "unit"), ("TWO", "couple")),
544+ )
545+ print(wrapper.get_enum("C")) # 3
546+ print(wrapper.get_enum("TWO")) # 2
547+ print(wrapper.enum_from_alias("alpha")) # 1
548+ print(wrapper.enum_from_alias("nothing")) # 0
530549
531550 # For the following, alias 'C' is not defined, so the method looks for
532551 # a constant named 'C', and finds it.
533- print(wrapper.enum_from_alias('C')) # 3
552+ print(wrapper.enum_from_alias("C")) # 3
534553
535554 .. note::
536555
@@ -567,12 +586,14 @@ def get_params(self, *objects, **kwargs) -> Params:
567586 from pytensor.tensor.type import dmatrix
568587 from pytensor.scalar import ScalarType
569588
589+
570590 class MyObject:
571591 def __init__(self):
572592 self.a = 10
573593 self.b = numpy.asarray([[1, 2, 3], [4, 5, 6]])
574594
575- params_type = ParamsType(a=ScalarType('int32'), b=dmatrix, c=ScalarType('bool'))
595+
596+ params_type = ParamsType(a=ScalarType("int32"), b=dmatrix, c=ScalarType("bool"))
576597
577598 o = MyObject()
578599 value_for_c = False
0 commit comments