Summary of this patch:
  * engines/schema: accept an optional sample_by engine argument and emit a
    "SAMPLE BY (...)" clause after "PRIMARY KEY (...)" in the table DDL.
  * functions: add toStartOfDay (subclass of toStartOfMinute) and export it.
  * queryset/query/compiler: add QuerySet.sample(sample_fraction,
    sample_offset=None); the values are stored on the Query, copied by
    clone(), and rendered as "SAMPLE <fraction> [OFFSET <offset>]" right
    after the FROM clause and before PREWHERE.
NOTE(review): the source of this patch had all line breaks and leading
whitespace collapsed. Line structure below matches every hunk header count;
indentation inside hunks was reconstructed from Python block structure and
should be confirmed against the target files before applying.
NOTE(review): sample_fraction / sample_offset are interpolated into the SQL
text with "%" formatting and are not validated anywhere in this patch;
confirm callers never pass untrusted values.

diff --git a/clickhouse_backend/backend/schema.py b/clickhouse_backend/backend/schema.py
index bc30f44..a4c8555 100644
--- a/clickhouse_backend/backend/schema.py
+++ b/clickhouse_backend/backend/schema.py
@@ -254,6 +254,7 @@ def _model_extra_sql(self, model, engine):
         order_by = engine.order_by
         partition_by = engine.partition_by
         primary_key = engine.primary_key
+        sample_by = engine.sample_by
 
         if order_by is not None:
             yield "ORDER BY (%s)" % self._get_expression(model, *order_by)
@@ -261,6 +262,8 @@ def _model_extra_sql(self, model, engine):
             yield "PARTITION BY (%s)" % self._get_expression(model, *partition_by)
         if primary_key is not None:
             yield "PRIMARY KEY (%s)" % self._get_expression(model, *primary_key)
+        if sample_by is not None:
+            yield "SAMPLE BY (%s)" % self._get_expression(model, *sample_by)
         if engine.settings:
             result = []
             for setting, value in engine.settings.items():
diff --git a/clickhouse_backend/models/engines.py b/clickhouse_backend/models/engines.py
index 67e992d..1055147 100644
--- a/clickhouse_backend/models/engines.py
+++ b/clickhouse_backend/models/engines.py
@@ -128,6 +128,7 @@ def __init__(
         order_by=None,
         partition_by=None,
         primary_key=None,
+        sample_by=None,
         **settings,
     ):
         assert (
@@ -136,8 +137,9 @@ def __init__(
         self.order_by = order_by
         self.primary_key = primary_key
         self.partition_by = partition_by
+        self.sample_by = sample_by
 
-        for key in ["order_by", "primary_key", "partition_by"]:
+        for key in ["order_by", "primary_key", "partition_by", "sample_by"]:
             value = getattr(self, key)
             if value is not None:
                 if isinstance(value, str) or not isinstance(value, Iterable):
diff --git a/clickhouse_backend/models/functions/datetime.py b/clickhouse_backend/models/functions/datetime.py
index 637eca9..4019877 100644
--- a/clickhouse_backend/models/functions/datetime.py
+++ b/clickhouse_backend/models/functions/datetime.py
@@ -11,6 +11,7 @@
     "toStartOfTenMinutes",
     "toStartOfFifteenMinutes",
     "toStartOfHour",
+    "toStartOfDay",
     "toYYYYMM",
     "toYYYYMMDD",
     "toYYYYMMDDhhmmss",
@@ -80,6 +81,10 @@ class toStartOfHour(toStartOfMinute):
     pass
 
 
+class toStartOfDay(toStartOfMinute):
+    pass
+
+
 class toYearWeek(Func):
     output_field = fields.UInt32Field()
 
diff --git a/clickhouse_backend/models/query.py b/clickhouse_backend/models/query.py
index d714ad6..ba5be59 100644
--- a/clickhouse_backend/models/query.py
+++ b/clickhouse_backend/models/query.py
@@ -31,6 +31,11 @@ def prewhere(self, *args, **kwargs):
         clone._query.add_prewhere(Q(*args, **kwargs))
         return clone
 
+    def sample(self, sample_fraction, sample_offset=None):
+        clone = self._chain()
+        clone._query.add_sample(sample_fraction, sample_offset)
+        return clone
+
     def datetimes(self, field_name, kind, order="ASC", tzinfo=None):
         """
         Return a list of datetime objects representing all available
diff --git a/clickhouse_backend/models/sql/compiler.py b/clickhouse_backend/models/sql/compiler.py
index a0569b8..f862b5f 100644
--- a/clickhouse_backend/models/sql/compiler.py
+++ b/clickhouse_backend/models/sql/compiler.py
@@ -122,6 +122,8 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
         refcounts_before = self.query.alias_refcount.copy()
         try:
             combinator = self.query.combinator
+            sample_fraction = self.query.sample_fraction
+            sample_offset = self.query.sample_offset
             if compat.dj_ge42:
                 extra_select, order_by, group_by = self.pre_sql_setup(
                     with_col_aliases=with_col_aliases or bool(combinator),
@@ -203,6 +205,13 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
                 result += ["FROM", *from_]
                 params.extend(f_params)
 
+                if sample_fraction:
+                    if sample_offset:
+                        sample_sql = "SAMPLE %s OFFSET %s" % (sample_fraction, sample_offset)
+                    else:
+                        sample_sql = "SAMPLE %s" % sample_fraction
+                    result.append(sample_sql)
+
                 if prewhere:
                     result.append("PREWHERE %s" % prewhere)
                     params.extend(p_params)
diff --git a/clickhouse_backend/models/sql/query.py b/clickhouse_backend/models/sql/query.py
index 9f3abd3..4305f3f 100644
--- a/clickhouse_backend/models/sql/query.py
+++ b/clickhouse_backend/models/sql/query.py
@@ -19,6 +19,8 @@ def __init__(self, model, where=query.WhereNode, alias_cols=True):
         super().__init__(model, where, alias_cols)
         self.setting_info = {}
         self.prewhere = query.WhereNode()
+        self.sample_fraction = None
+        self.sample_offset = None
 
     def sql_with_params(self):
         """Choose the right db when database router is used."""
@@ -28,6 +30,8 @@ def clone(self):
         obj = super().clone()
         obj.setting_info = self.setting_info.copy()
         obj.prewhere = self.prewhere.clone()
+        obj.sample_fraction = self.sample_fraction
+        obj.sample_offset = self.sample_offset
         return obj
 
     def explain(self, using, format=None, type=None, **settings):
@@ -36,6 +40,10 @@ def explain(self, using, format=None, type=None, **settings):
         compiler = q.get_compiler(using=using)
         return "\n".join(compiler.explain_query())
 
+    def add_sample(self, sample_fraction, sample_offset):
+        self.sample_fraction = sample_fraction
+        self.sample_offset = sample_offset
+
     def add_prewhere(self, q_object):
         """
         refer add_q