Skip to content

Commit eaa184b

Browse files
author
nagytech
authored
Merge branch 'master' into feature/optgroup
2 parents 1c267f4 + 08a9ad8 commit eaa184b

File tree

3 files changed

+53
-19
lines changed

3 files changed

+53
-19
lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ packages = wtforms_sqlalchemy
3434
include_package_data = true
3535
python_requires = >= 3.8
3636
install_requires =
37-
WTForms>=1.0.5,<3.1
37+
WTForms>=3.1
3838
SQLAlchemy>=0.7.10,<2
3939

4040
[flake8]

tests/test_main.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
class LazySelect:
3535
def __call__(self, field, **kwargs):
3636
return list(
37-
(val, str(label), selected) for val, label, selected in field.iter_choices()
37+
(val, str(label), selected, render_kw)
38+
for val, label, selected, render_kw in field.iter_choices()
3839
)
3940

4041

@@ -111,7 +112,9 @@ class F(Form):
111112
form.a.query = sess.query(self.Test)
112113
self.assertTrue(form.a.data is not None)
113114
self.assertEqual(form.a.data.id, 1)
114-
self.assertEqual(form.a(), [("1", "apple", True), ("2", "banana", False)])
115+
self.assertEqual(
116+
form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})]
117+
)
115118
self.assertTrue(form.validate())
116119

117120
form = F(a=sess.query(self.Test).filter_by(name="banana").first())
@@ -144,28 +147,32 @@ class F(Form):
144147

145148
form = F()
146149
self.assertEqual(form.a.data, None)
147-
self.assertEqual(form.a(), [("1", "apple", False), ("2", "banana", False)])
150+
self.assertEqual(
151+
form.a(), [("1", "apple", False, {}), ("2", "banana", False, {})]
152+
)
148153
self.assertEqual(form.b.data, None)
149154
self.assertEqual(
150155
form.b(),
151156
[
152-
("__None", "", True),
153-
("hello1", "apple", False),
154-
("hello2", "banana", False),
157+
("__None", "", True, {}),
158+
("hello1", "apple", False, {}),
159+
("hello2", "banana", False, {}),
155160
],
156161
)
157162
self.assertFalse(form.validate())
158163

159164
form = F(DummyPostData(a=["1"], b=["hello2"]))
160165
self.assertEqual(form.a.data.id, 1)
161-
self.assertEqual(form.a(), [("1", "apple", True), ("2", "banana", False)])
166+
self.assertEqual(
167+
form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})]
168+
)
162169
self.assertEqual(form.b.data.baz, "banana")
163170
self.assertEqual(
164171
form.b(),
165172
[
166-
("__None", "", False),
167-
("hello1", "apple", False),
168-
("hello2", "banana", True),
173+
("__None", "", False, {}),
174+
("hello1", "apple", False, {}),
175+
("hello2", "banana", True, {}),
169176
],
170177
)
171178
self.assertTrue(form.validate())
@@ -174,11 +181,17 @@ class F(Form):
174181
sess.add(self.Test(id=3, name="meh"))
175182
sess.flush()
176183
sess.commit()
177-
self.assertEqual(form.a(), [("1", "apple", True), ("2", "banana", False)])
184+
self.assertEqual(
185+
form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})]
186+
)
178187
form.a._object_list = None
179188
self.assertEqual(
180189
form.a(),
181-
[("1", "apple", True), ("2", "banana", False), ("3", "meh", False)],
190+
[
191+
("1", "apple", True, {}),
192+
("2", "banana", False, {}),
193+
("3", "meh", False, {}),
194+
],
182195
)
183196

184197
# Test bad data
@@ -217,14 +230,18 @@ def test_single_value_without_factory(self):
217230
form = self.F(DummyPostData(a=["1"]))
218231
form.a.query = self.sess.query(self.Test)
219232
self.assertEqual([1], [v.id for v in form.a.data])
220-
self.assertEqual(form.a(), [("1", "apple", True), ("2", "banana", False)])
233+
self.assertEqual(
234+
form.a(), [("1", "apple", True, {}), ("2", "banana", False, {})]
235+
)
221236
self.assertTrue(form.validate())
222237

223238
def test_multiple_values_without_query_factory(self):
224239
form = self.F(DummyPostData(a=["1", "2"]))
225240
form.a.query = self.sess.query(self.Test)
226241
self.assertEqual([1, 2], [v.id for v in form.a.data])
227-
self.assertEqual(form.a(), [("1", "apple", True), ("2", "banana", True)])
242+
self.assertEqual(
243+
form.a(), [("1", "apple", True, {}), ("2", "banana", True, {})]
244+
)
228245
self.assertTrue(form.validate())
229246

230247
form = self.F(DummyPostData(a=["1", "3"]))
@@ -245,7 +262,9 @@ class F(Form):
245262

246263
form = F()
247264
self.assertEqual([v.id for v in form.a.data], [2])
248-
self.assertEqual(form.a(), [("1", "apple", False), ("2", "banana", True)])
265+
self.assertEqual(
266+
form.a(), [("1", "apple", False, {}), ("2", "banana", True, {})]
267+
)
249268
self.assertTrue(form.validate())
250269

251270

wtforms_sqlalchemy/fields.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class QuerySelectField(SelectFieldBase):
4949
model instance and expected to return the label text. Otherwise, the model
5050
object's `__str__` will be used.
5151
52+
5253
Specify `get_group` to allow `option` elements to be grouped into `optgroup`
5354
sections. If a string, this is the name of an attribute on the model
5455
containing the group name. If a one-argument callable, this callable will
@@ -57,6 +58,12 @@ class QuerySelectField(SelectFieldBase):
5758
will be used as both the grouping key and the display label in the `select`
5859
options.
5960
61+
Specify `get_render_kw` to apply HTML attributes to each option. If a
62+
string, this is the name of an attribute on the model containing a
63+
dictionary. If a one-argument callable, this callable will be passed the
64+
model instance and expected to return a dictionary. Otherwise, an empty
65+
dictionary will be used.
66+
6067
If `allow_blank` is set to `True`, then a blank choice will be added to the
6168
top of the list. Selecting this choice will result in the `data` property
6269
being `None`. The label for this blank choice can be set by specifying the
@@ -73,6 +80,7 @@ def __init__(
7380
get_pk=None,
7481
get_label=None,
7582
get_group=None,
83+
get_render_kw=None,
7684
allow_blank=False,
7785
blank_text="",
7886
**kwargs
@@ -105,6 +113,13 @@ def __init__(
105113
else:
106114
self.get_group = get_group
107115

116+
if get_render_kw is None:
117+
self.get_render_kw = lambda _: {}
118+
elif isinstance(get_render_kw, str):
119+
self.get_render_kw = operator.attrgetter(get_render_kw)
120+
else:
121+
self.get_render_kw = get_render_kw
122+
108123
self.allow_blank = allow_blank
109124
self.blank_text = blank_text
110125
self.query = None
@@ -133,10 +148,10 @@ def _get_object_list(self):
133148

134149
def iter_choices(self):
135150
if self.allow_blank:
136-
yield ("__None", self.blank_text, self.data is None)
151+
yield ("__None", self.blank_text, self.data is None, {})
137152

138153
for pk, obj in self._get_object_list():
139-
yield (pk, self.get_label(obj), obj == self.data)
154+
yield (pk, self.get_label(obj), obj == self.data, self.get_render_kw(obj))
140155

141156
def has_groups(self):
142157
return self._has_groups
@@ -225,7 +240,7 @@ def _set_data(self, data):
225240

226241
def iter_choices(self):
227242
for pk, obj in self._get_object_list():
228-
yield (pk, self.get_label(obj), obj in self.data)
243+
yield (pk, self.get_label(obj), obj in self.data, self.get_render_kw(obj))
229244

230245
def process_formdata(self, valuelist):
231246
self._formdata = set(valuelist)

0 commit comments

Comments
 (0)