|
1 | 1 | from datetime import datetime |
2 | 2 |
|
3 | 3 | from sqlalchemy import Index |
| 4 | +from werkzeug.security import generate_password_hash, check_password_hash |
| 5 | +from email_validator import validate_email, EmailNotValidError |
| 6 | +from email_normalize import normalize |
4 | 7 |
|
5 | 8 | from app import db |
6 | 9 |
|
7 | 10 |
|
| 11 | +class User(db.Model): |
| 12 | + __tablename__ = 'user' |
| 13 | + id = db.Column(db.Integer, primary_key=True) |
| 14 | + email = db.Column(db.String(120), nullable=False) |
| 15 | + email_normalized = db.Column(db.String(120), nullable=False, unique=True) |
| 16 | + password_hash = db.Column(db.String(256), nullable=False) |
| 17 | + created_on = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) |
| 18 | + |
| 19 | + # Does not check for non-deliverable mails. Use check_deliverability or resolve for that which does DNS checks |
| 20 | + # For more stricter validation, use confirmation emails, or a third party API |
| 21 | + @staticmethod |
| 22 | + def _normalize_email(email): |
| 23 | + # Follows RFCs, allows aliases and only lowers the domain part |
| 24 | + validated = validate_email(email, check_deliverability=False) |
| 25 | + # Lowers the local part and normalizes, removes aliases for popular email providers (gmail, yahoo etc) |
| 26 | + normalized = normalize(validated.email, resolve=False) |
| 27 | + return normalized |
| 28 | + |
| 29 | + @staticmethod |
| 30 | + def get(email): |
| 31 | + try: |
| 32 | + email_normalized = User._normalize_email(email) |
| 33 | + except EmailNotValidError: |
| 34 | + return None |
| 35 | + return User.query.filter_by(email_normalized=email_normalized).scalar() |
| 36 | + |
| 37 | + def set_email(self, email): |
| 38 | + try: |
| 39 | + self.email_normalized = self._normalize_email(email) |
| 40 | + self.email = email |
| 41 | + except EmailNotValidError as e: |
| 42 | + raise e |
| 43 | + |
| 44 | + def set_password(self, password): |
| 45 | + # scrypt stores salt with the hash, which it uses to verify the password |
| 46 | + self.password_hash = generate_password_hash(password, method='scrypt') |
| 47 | + |
| 48 | + def check_password(self, password): |
| 49 | + return check_password_hash(self.password_hash, password) |
| 50 | + |
| 51 | + |
8 | 52 | # https://stackoverflow.com/questions/2190272/sql-many-to-many-table-primary-key |
9 | 53 | category_subcategory = db.Table("category_subcategory", |
10 | 54 | db.Column("category_id", db.Integer, db.ForeignKey("category.id", ondelete="CASCADE", onupdate="CASCADE"), primary_key=True), |
|
0 commit comments