Skip to content

Commit ea17221

Browse files
committed
adding enum field
1 parent 9cf13ca commit ea17221

File tree

1 file changed

+49
-7
lines changed

1 file changed

+49
-7
lines changed

fields/fields.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import timedelta
22
from mongoengine.base import BaseField
3-
from mongoengine.fields import StringField, EmailField
3+
from mongoengine.fields import IntField, StringField, EmailField
44

55
import os
66
import datetime
@@ -10,7 +10,6 @@
1010
from django.core.files.base import File
1111
from django.core.files.storage import default_storage
1212

13-
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
1413
from django.utils.encoding import force_str, force_text
1514

1615

@@ -63,7 +62,7 @@ def __init__(self,
6362
upload_to='',
6463
storage=None,
6564
**kwargs):
66-
self.size=size
65+
self.size = size
6766
self.storage = storage or default_storage
6867
self.upload_to = upload_to
6968
if callable(upload_to):
@@ -73,7 +72,7 @@ def __init__(self,
7372
def __get__(self, instance, owner):
7473
if instance is None:
7574
return self
76-
75+
7776
file = instance._data.get(self.name)
7877

7978
if isinstance(file, str_types) or file is None:
@@ -101,13 +100,16 @@ def __set__(self, instance, value):
101100
instance._mark_as_changed(key)
102101

103102
def get_directory_name(self):
104-
return os.path.normpath(force_text(datetime.datetime.now().strftime(force_str(self.upload_to))))
103+
return os.path.normpath(force_text(
104+
datetime.datetime.now().strftime(force_str(self.upload_to))))
105105

106106
def get_filename(self, filename):
107-
return os.path.normpath(self.storage.get_valid_name(os.path.basename(filename)))
107+
return os.path.normpath(
108+
self.storage.get_valid_name(os.path.basename(filename)))
108109

109110
def generate_filename(self, instance, filename):
110-
return os.path.join(self.get_directory_name(), self.get_filename(filename))
111+
return os.path.join(
112+
self.get_directory_name(), self.get_filename(filename))
111113

112114
def to_mongo(self, value):
113115
if isinstance(value, self.proxy_class):
@@ -137,3 +139,43 @@ def validate(self, value):
137139
self.error('Invalid Mail-address: %s' % value)
138140
super(LowerEmailField, self).validate(value)
139141

142+
143+
class EnumField(object):
144+
"""
145+
A class to register Enum type into mongo
146+
147+
:param choices: must be of enum type and will be used as possible choices
148+
"""
149+
150+
def __init__(self, enum, *args, **kwargs):
151+
self.enum = enum
152+
kwargs['choices'] = [choice for choice in enum]
153+
super(EnumField, self).__init__(*args, **kwargs)
154+
155+
def __get_value(self, enum):
156+
return enum.value if hasattr(enum, 'value') else enum
157+
158+
def to_python(self, value):
159+
return self.enum(super(EnumField, self).to_python(value))
160+
161+
def to_mongo(self, value):
162+
return self.__get_value(value)
163+
164+
def prepare_query_value(self, op, value):
165+
return super(EnumField, self).prepare_query_value(
166+
op, self.__get_value(value))
167+
168+
def validate(self, value):
169+
return super(EnumField, self).validate(self.__get_value(value))
170+
171+
def _validate(self, value, **kwargs):
172+
return super(EnumField, self)._validate(
173+
self.enum(self.__get_value(value)), **kwargs)
174+
175+
176+
class IntEnumField(EnumField, IntField):
177+
pass
178+
179+
180+
class StringEnumField(EnumField, StringField):
181+
pass

0 commit comments

Comments
 (0)