Skip to content

Commit 0b68540

Browse files
feedback
1 parent 2378640 commit 0b68540

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

elasticsearch_dsl/aggs.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,6 +2788,39 @@ def __init__(
27882788
super().__init__(path=path, **kwargs)
27892789

27902790

2791+
class RandomSampler(Bucket[_R]):
2792+
"""
2793+
A single bucket aggregation that randomly includes documents in the
2794+
aggregated results. Sampling provides significant speed improvement at
2795+
the cost of accuracy.
2796+
2797+
:arg probability: (required) The probability that a document will be
2798+
included in the aggregated data. Must be greater than 0 and less than
2799+
0.5, or exactly 1. The lower the probability, the fewer documents
2800+
are matched.
2801+
:arg seed: The seed to generate the random sampling of documents. When
2802+
a seed is provided, the random subset of documents is the same
2803+
between calls.
2804+
:arg shard_seed: When combined with seed, setting shard_seed ensures
2805+
100% consistent sampling over shards where data is exactly the
2806+
same.
2807+
"""
2808+
2809+
name = "random_sampler"
2810+
2811+
def __init__(
2812+
self,
2813+
*,
2814+
probability: Union[float, "DefaultType"] = DEFAULT,
2815+
seed: Union[int, "DefaultType"] = DEFAULT,
2816+
shard_seed: Union[int, "DefaultType"] = DEFAULT,
2817+
**kwargs: Any,
2818+
):
2819+
super().__init__(
2820+
probability=probability, seed=seed, shard_seed=shard_seed, **kwargs
2821+
)
2822+
2823+
27912824
class Sampler(Bucket[_R]):
27922825
"""
27932826
A filtering aggregation used to limit any sub aggregations' processing
@@ -3696,7 +3729,3 @@ def __init__(
36963729

36973730
def result(self, search: "SearchBase[_R]", data: Any) -> AttrDict[Any]:
36983731
return FieldBucketData(self, search, data)
3699-
3700-
3701-
class RandomSampler(Bucket[_R]):
3702-
name = "random_sampler"

tests/test_aggs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def test_filters_correctly_identifies_the_hash() -> None:
220220

221221

222222
def test_bucket_sort_agg() -> None:
223+
# test the dictionary (type ignored) and fully typed alternatives
223224
bucket_sort_agg = aggs.BucketSort(sort=[{"total_sales": {"order": "desc"}}], size=3) # type: ignore
224225
assert bucket_sort_agg.to_dict() == {
225226
"bucket_sort": {"sort": [{"total_sales": {"order": "desc"}}], "size": 3}
@@ -251,6 +252,7 @@ def test_bucket_sort_agg() -> None:
251252

252253

253254
def test_bucket_sort_agg_only_trnunc() -> None:
255+
# test the dictionary (type ignored) and fully typed alternatives
254256
bucket_sort_agg = aggs.BucketSort(**{"from": 1, "size": 1, "_expand__to_dot": False}) # type: ignore
255257
assert bucket_sort_agg.to_dict() == {"bucket_sort": {"from": 1, "size": 1}}
256258
bucket_sort_agg = aggs.BucketSort(from_=1, size=1, _expand__to_dot=False)
@@ -265,20 +267,23 @@ def test_bucket_sort_agg_only_trnunc() -> None:
265267

266268

267269
def test_geohash_grid_aggregation() -> None:
270+
# test the dictionary (type ignored) and fully typed alternatives
268271
a = aggs.GeohashGrid(**{"field": "centroid", "precision": 3}) # type: ignore
269272
assert {"geohash_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
270273
a = aggs.GeohashGrid(field="centroid", precision=3)
271274
assert {"geohash_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
272275

273276

274277
def test_geohex_grid_aggregation() -> None:
278+
# test the dictionary (type ignored) and fully typed alternatives
275279
a = aggs.GeohexGrid(**{"field": "centroid", "precision": 3}) # type: ignore
276280
assert {"geohex_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
277281
a = aggs.GeohexGrid(field="centroid", precision=3)
278282
assert {"geohex_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
279283

280284

281285
def test_geotile_grid_aggregation() -> None:
286+
# test the dictionary (type ignored) and fully typed alternatives
282287
a = aggs.GeotileGrid(**{"field": "centroid", "precision": 3}) # type: ignore
283288
assert {"geotile_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
284289
a = aggs.GeotileGrid(field="centroid", precision=3)
@@ -318,6 +323,7 @@ def test_variable_width_histogram_aggregation() -> None:
318323

319324

320325
def test_ip_prefix_aggregation() -> None:
326+
# test the dictionary (type ignored) and fully typed alternatives
321327
a = aggs.IPPrefix(**{"field": "ipv4", "prefix_length": 24}) # type: ignore
322328
assert {"ip_prefix": {"field": "ipv4", "prefix_length": 24}} == a.to_dict()
323329
a = aggs.IPPrefix(field="ipv4", prefix_length=24)
@@ -501,6 +507,7 @@ def test_adjancecy_matrix_aggregation() -> None:
501507

502508

503509
def test_top_metrics_aggregation() -> None:
510+
# test the dictionary (type ignored) and fully typed alternatives
504511
a = aggs.TopMetrics(metrics={"field": "m"}, sort={"s": "desc"}) # type: ignore
505512
assert {
506513
"top_metrics": {"metrics": {"field": "m"}, "sort": {"s": "desc"}}

utils/templates/aggs.py.tpl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,3 @@ class {{ k.name }}({{ k.parent if k.parent else parent }}[_R]):
318318

319319
{% endif %}
320320
{% endfor %}
321-
{# the following aggregation is in technical preview and does not exist in the specification #}
322-
class RandomSampler(Bucket[_R]):
323-
name = "random_sampler"

0 commit comments

Comments
 (0)