Skip to content

Commit e813d41

Browse files
committed
Tweak writing of all attributes, allow writing only configurable attributes
1 parent 4b91d49 commit e813d41

File tree

3 files changed

+137
-59
lines changed

3 files changed

+137
-59
lines changed

docs/attr_doc_gen.py

Lines changed: 102 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
import hls4ml.model.attributes as attributes
55
import hls4ml.model.layers as layers
66

7-
all_backends = backends.get_available_backends()
8-
# Removing duplicates but preserving order
9-
all_layers = list(dict.fromkeys(layers.layer_map.values()))
10-
117

128
class AttrList:
139
def __init__(self, cls_name, cls_attrs) -> None:
@@ -17,29 +13,61 @@ def __init__(self, cls_name, cls_attrs) -> None:
1713
self.weight_attrs = [attr for attr in cls_attrs if attr.__class__.__name__ == 'WeightAttribute']
1814
self.base_attrs = [attr for attr in cls_attrs if attr not in self.config_attrs + self.type_attrs + self.weight_attrs]
1915
self.backend_attrs = {}
16+
self.reverse_backend_attrs = [] # Will hold (attr, backend_name) pairs, used temporarily
17+
self.unique_backend_attrs = []
2018

2119
def add_backend_attrs(self, backend_name, backend_attrs):
2220
self.backend_attrs[backend_name] = backend_attrs
2321

22+
for attr in backend_attrs:
23+
self.reverse_backend_attrs.append((attr, backend_name))
24+
25+
def sift_backend_attrs(self):
26+
grouped_dict = {}
27+
for attr, backend_name in self.reverse_backend_attrs:
28+
if attr not in grouped_dict:
29+
grouped_dict[attr] = []
30+
grouped_dict[attr].append(backend_name)
31+
32+
for attr, backend_names in grouped_dict.items():
33+
attr.available_in = backend_names
34+
self.unique_backend_attrs.append(attr)
35+
36+
@property
37+
def only_configurable(self):
38+
all_attrs = self.config_attrs + self.type_attrs + self.unique_backend_attrs
39+
return [attr for attr in all_attrs if attr.configurable is True]
40+
2441

25-
attr_map = []
42+
def convert_to_attr_list():
43+
all_backends = backends.get_available_backends()
44+
# Removing duplicates but preserving order
45+
all_layers = list(dict.fromkeys(layers.layer_map.values()))
46+
all_layers_attrs = []
2647

27-
for layer_cls in all_layers:
28-
base_attrs = layer_cls.expected_attributes
48+
for layer_cls in all_layers:
49+
base_attrs = layer_cls.expected_attributes
2950

30-
attr_list = AttrList(layer_cls.__name__, base_attrs)
51+
attr_list = AttrList(layer_cls.__name__, base_attrs)
3152

32-
for backend_name in all_backends:
33-
backend = backends.get_backend(backend_name)
53+
for backend_name in all_backends:
54+
backend = backends.get_backend(backend_name)
3455

35-
backend_cls = backend.create_layer_class(layer_cls)
36-
backend_attrs = backend_cls.expected_attributes
56+
backend_cls = backend.create_layer_class(layer_cls)
57+
backend_attrs = backend_cls.expected_attributes
3758

38-
diff_atts = [attr for attr in backend_attrs if attr not in base_attrs] # Sets are faster, but don't preserve order
39-
if len(diff_atts) > 0:
40-
attr_list.add_backend_attrs(backend.name, diff_atts)
59+
diff_atts = [
60+
attr for attr in backend_attrs if attr not in base_attrs
61+
] # Sets are faster, but don't preserve order
62+
if len(diff_atts) > 0:
63+
attr_list.add_backend_attrs(backend.name, diff_atts)
4164

42-
attr_map.append(attr_list)
65+
all_layers_attrs.append(attr_list)
66+
67+
for attr_list in all_layers_attrs:
68+
attr_list.sift_backend_attrs()
69+
70+
return all_layers_attrs
4371

4472

4573
def print_attrs(attrs, file):
@@ -60,40 +88,62 @@ def print_attrs(attrs, file):
6088
if attr.description is not None:
6189
file.write(' * ' + attr.description + '\n\n')
6290

91+
if hasattr(attr, 'available_in'):
92+
file.write(' * Available in: ' + ', '.join(attr.available_in) + '\n\n')
93+
94+
95+
def write_all_attributes(all_layers_attrs):
96+
with open('attributes.rst', mode='w') as file:
97+
file.write('================\n')
98+
file.write('Layer attributes\n')
99+
file.write('================\n\n\n')
100+
101+
for attr_list in all_layers_attrs:
102+
file.write(attr_list.cls_name + '\n')
103+
file.write('=' * len(attr_list.cls_name) + '\n')
104+
105+
if len(attr_list.base_attrs) > 0:
106+
file.write('Base attributes\n')
107+
file.write('---------------\n')
108+
print_attrs(attr_list.type_attrs, file)
109+
110+
if len(attr_list.type_attrs) > 0:
111+
file.write('Type attributes\n')
112+
file.write('---------------\n')
113+
print_attrs(attr_list.base_attrs, file)
114+
115+
if len(attr_list.weight_attrs) > 0:
116+
file.write('Weight attributes\n')
117+
file.write('-----------------\n')
118+
print_attrs(attr_list.weight_attrs, file)
119+
120+
if len(attr_list.config_attrs) > 0:
121+
file.write('Configurable attributes\n')
122+
file.write('-----------------------\n')
123+
print_attrs(attr_list.config_attrs, file)
124+
125+
if len(attr_list.backend_attrs) > 0:
126+
file.write('Backend-specific attributes\n')
127+
file.write('---------------------------\n')
128+
print_attrs(attr_list.unique_backend_attrs, file)
129+
130+
131+
def write_only_configurable(all_layers_attrs):
132+
with open('attributes.rst', mode='w') as file:
133+
file.write('================\n')
134+
file.write('Layer attributes\n')
135+
file.write('================\n\n\n')
136+
137+
for attr_list in all_layers_attrs:
138+
file.write(attr_list.cls_name + '\n')
139+
file.write('=' * len(attr_list.cls_name) + '\n')
140+
141+
config_attrs = attr_list.only_configurable
142+
if len(config_attrs) > 0:
143+
print_attrs(config_attrs, file)
144+
63145

64-
with open('attributes.rst', mode='w') as file:
65-
file.write('================\n')
66-
file.write('Layer attributes\n')
67-
file.write('================\n\n\n')
68-
69-
for attr_list in attr_map:
70-
file.write(attr_list.cls_name + '\n')
71-
file.write('=' * len(attr_list.cls_name) + '\n')
72-
73-
if len(attr_list.base_attrs) > 0:
74-
file.write('Base attributes\n')
75-
file.write('---------------\n')
76-
print_attrs(attr_list.type_attrs, file)
77-
78-
if len(attr_list.type_attrs) > 0:
79-
file.write('Type attributes\n')
80-
file.write('---------------\n')
81-
print_attrs(attr_list.base_attrs, file)
82-
83-
if len(attr_list.weight_attrs) > 0:
84-
file.write('Weight attributes\n')
85-
file.write('-----------------\n')
86-
print_attrs(attr_list.weight_attrs, file)
87-
88-
if len(attr_list.config_attrs) > 0:
89-
file.write('Configurable attributes\n')
90-
file.write('-----------------------\n')
91-
print_attrs(attr_list.config_attrs, file)
92-
93-
if len(attr_list.backend_attrs) > 0:
94-
file.write('Backend attributes\n')
95-
file.write('-----------------------\n')
96-
for backend, backend_attrs in attr_list.backend_attrs.items():
97-
file.write(backend + '\n')
98-
file.write('^' * len(backend) + '\n')
99-
print_attrs(backend_attrs, file)
146+
if __name__ == '__main__':
147+
all_layers_attrs = convert_to_attr_list()
148+
write_all_attributes(all_layers_attrs)
149+
# write_only_configurable(all_layers_attrs)

hls4ml/model/attributes.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ def config_name(self):
6060
"""
6161
return convert_to_pascal_case(self.name)
6262

63+
def __eq__(self, other: object) -> bool:
64+
if not isinstance(other, Attribute):
65+
return NotImplemented
66+
return (
67+
self.name == other.name
68+
and self.value_type == other.value_type
69+
and self.default == other.default
70+
and self.configurable == other.configurable
71+
and self.description == other.description
72+
)
73+
74+
def __hash__(self) -> int:
75+
return hash((self.name, self.value_type, self.default, self.configurable, self.description))
76+
6377

6478
class ConfigurableAttribute(Attribute):
6579
"""
@@ -69,7 +83,7 @@ class ConfigurableAttribute(Attribute):
6983
when defining the expected attributes of layer classes.
7084
"""
7185

72-
def __init__(self, name, value_type=int, default=None, description=None):
86+
def __init__(self, name, value_type=Integral, default=None, description=None):
7387
super().__init__(name, value_type, default, configurable=True, description=description)
7488

7589

@@ -101,6 +115,13 @@ def __init__(self, name, choices, default=None, configurable=True, description=N
101115
def validate_value(self, value):
102116
return value in self.choices
103117

118+
def __eq__(self, other: object) -> bool:
119+
base_eq = super().__eq__(other)
120+
return base_eq and hasattr(other, 'choices') and set(self.choices) == set(other.choices)
121+
122+
def __hash__(self) -> int:
123+
return super().__hash__() ^ hash(tuple(sorted(self.choices)))
124+
104125

105126
class WeightAttribute(Attribute):
106127
"""
@@ -117,9 +138,7 @@ class CodeAttrubute(Attribute):
117138
"""
118139

119140
def __init__(self, name, description=None):
120-
super(WeightAttribute, self).__init__(
121-
name, value_type=Source, default=None, configurable=False, description=description
122-
)
141+
super().__init__(name, value_type=Source, default=None, configurable=False, description=description)
123142

124143

125144
# endregion

hls4ml/model/types.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,15 @@ def __init__(self, width, signed):
6464
self.width = width
6565
self.signed = signed
6666

67-
def __eq__(self, other):
67+
def __eq__(self, other: object) -> bool:
6868
eq = self.width == other.width
6969
eq = eq and self.signed == other.signed
7070

7171
return eq
7272

73+
def __hash__(self) -> int:
74+
return hash((self.width, self.signed))
75+
7376

7477
class IntegerPrecisionType(PrecisionType):
7578
"""Arbitrary precision integer data type.
@@ -89,12 +92,15 @@ def __str__(self):
8992
return typestring
9093

9194
# Does this need to make sure other is also an IntegerPrecisionType? I could see a match between Fixed and Integer
92-
def __eq__(self, other):
95+
def __eq__(self, other: object) -> bool:
9396
if isinstance(other, IntegerPrecisionType):
9497
return super().__eq__(other)
9598

9699
return False
97100

101+
def __hash__(self) -> int:
102+
return super().__hash__()
103+
98104
@property
99105
def integer(self):
100106
return self.width
@@ -186,7 +192,7 @@ def __str__(self):
186192
typestring = '{signed}fixed<{args}>'.format(signed='u' if not self.signed else '', args=args)
187193
return typestring
188194

189-
def __eq__(self, other):
195+
def __eq__(self, other: object) -> bool:
190196
if isinstance(other, FixedPrecisionType):
191197
eq = super().__eq__(other)
192198
eq = eq and self.integer == other.integer
@@ -197,6 +203,9 @@ def __eq__(self, other):
197203

198204
return False
199205

206+
def __hash__(self) -> int:
207+
return super().__hash__() ^ hash((self.integer, self.rounding_mode, self.saturation_mode, self.saturation_bits))
208+
200209

201210
class XnorPrecisionType(PrecisionType):
202211
"""

0 commit comments

Comments (0)