Skip to content

Commit 7cc138f

Browse files
andylou2tfx-copybara
authored andcommitted
Added equi_join_any_indices_op.
PiperOrigin-RevId: 431565983
1 parent ef7bcd9 commit 7cc138f

14 files changed

+553
-137
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
## Major Features and Improvements
66

7+
* Added equi_join_any_indices_op.
8+
79
## Bug Fixes and Other Changes
810

911
* Depends on `tensorflow>=2.8.0,<2.9`.

struct2tensor/benchmarks/BUILD

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,39 @@ py_binary(
8888
],
8989
)
9090

91+
# py_test(
92+
# name = "ops_benchmark_test",
93+
# srcs = ["ops_benchmark.py"],
94+
# # Reduce the run time to fit on TAP.
95+
# # Follow the instructions in the file to properly run the benchmark.
96+
# args = ["--test_mode"],
97+
# main = "ops_benchmark.py",
98+
# python_version = "PY3",
99+
# # shard_count = 4,
100+
# deps = [
101+
# ":struct2tensor_benchmark_lib",
102+
# ],
103+
# )
104+
py_test(
105+
name = "ops_benchmark_test",
106+
srcs = ["ops_benchmark.py"],
107+
main = "ops_benchmark.py",
108+
deps = [
109+
":struct2tensor_benchmark_util",
110+
"//struct2tensor/ops:struct2tensor_ops",
111+
"@absl_py//absl/testing:parameterized",
112+
],
113+
)
114+
115+
py_binary(
116+
name = "ops_benchmark",
117+
srcs = ["ops_benchmark.py"],
118+
python_version = "PY3",
119+
deps = [
120+
":struct2tensor_benchmark_lib",
121+
],
122+
)
123+
91124
py_library(
92125
name = "struct2tensor_benchmark_util",
93126
srcs = ["struct2tensor_benchmark_util.py"],
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# pylint: disable=line-too-long
15+
r"""Benchmarks for struct2tensor.
16+
17+
18+
Usage:
19+
blaze run -c opt --dynamic_mode=off \
20+
--run_under='perflab \
21+
--constraints=arch=x86_64,platform_family=iota,platform_genus=sandybridge' \
22+
//struct2tensor/benchmarks:ops_benchmark \
23+
-- --notest_mode
24+
25+
Results:
26+
Num Iterations|Total (wall) Time (s)|Wall Time avg(ms)|Wall Time std|User CPU avg (ms)|User CPU std|System CPU avg (ms)|System CPU std
27+
equi_join_any_indices_monotonic_increasing_1000: 10000|12.046009018551558|1.2046009018551558|0.9495328669791991|1.2590000000000146|6.046119048963948|0.09000000000014552|5.595813730989087
28+
equi_join_any_indices_random_1000: 10000|12.026614569593221|1.2026614569593221|0.7186884686309326|1.2809999999997672|17.961518911690458|0.10100000000093132|10.87068154147507
29+
equi_join_indices_monotonic_increasing_1000: 10000|12.022087568882853|1.2022087568882853|0.6371426815230887|1.256000000000131|9.673456340434292|0.10300000000061119|13.443108707001388
30+
equi_join_indices_random_1000: 10000|12.04043987870682|1.2040439878706821|0.6657185129990686|1.2420000000001892|7.272474742950086|0.08600000000005821|5.508487482377499
31+
"""
32+
# pylint: disable=line-too-long
33+
34+
from absl.testing import parameterized
35+
from struct2tensor.benchmarks import struct2tensor_benchmark_util
36+
from struct2tensor.ops import struct2tensor_ops
37+
import tensorflow as tf
38+
39+
40+
class EquiJoinIndicesBenchmarks(struct2tensor_benchmark_util.OpsBenchmarks):
41+
"""Benchmarks for EquiJoinIndices."""
42+
43+
@parameterized.named_parameters(*[
44+
dict(
45+
testcase_name="equi_join_indices_monotonic_increasing",
46+
fn_name="equi_join_indices_monotonic_increasing",
47+
fn_args=[],
48+
data_key="monotonic_increasing",
49+
),
50+
dict(
51+
testcase_name="equi_join_indices_random",
52+
fn_name="equi_join_indices_random",
53+
fn_args=[],
54+
data_key="random",
55+
),
56+
])
57+
def test_equi_join_indices(self, fn_name, fn_args, data_key):
58+
59+
def benchmark_fn(session):
60+
a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,))
61+
b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,))
62+
result = struct2tensor_ops.equi_join_indices(a, b)
63+
with tf.control_dependencies(result):
64+
x = tf.constant(1)
65+
return session.make_callable(x, feed_list=[a, b])
66+
67+
self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key)
68+
69+
70+
class EquiJoinAnyIndicesBenchmarks(struct2tensor_benchmark_util.OpsBenchmarks):
71+
"""Benchmarks for EquiJoinAnyIndices."""
72+
73+
@parameterized.named_parameters(*[
74+
dict(
75+
testcase_name="equi_join_any_indices_monotonic_increasing",
76+
fn_name="equi_join_any_indices_monotonic_increasing",
77+
fn_args=[],
78+
data_key="monotonic_increasing",
79+
),
80+
dict(
81+
testcase_name="equi_join_any_indices_random",
82+
fn_name="equi_join_any_indices_random",
83+
fn_args=[],
84+
data_key="random",
85+
),
86+
])
87+
def test_equi_join_indices(self, fn_name, fn_args, data_key):
88+
89+
def benchmark_fn(session):
90+
a = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,))
91+
b = tf.compat.v1.placeholder(dtype=tf.int64, shape=(None,))
92+
result = struct2tensor_ops.equi_join_any_indices(a, b)
93+
with tf.control_dependencies(result):
94+
x = tf.constant(1)
95+
return session.make_callable(x, feed_list=[a, b])
96+
97+
self.run_benchmarks(fn_name, benchmark_fn, fn_args, data_key)
98+
99+
100+
if __name__ == "__main__":
101+
tf.test.main()

struct2tensor/benchmarks/struct2tensor_benchmark.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import tensorflow as tf
3232

3333

34-
class ProjectBenchmarks(struct2tensor_benchmark_util.Struct2tensorBenchmarks):
34+
class ProjectBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks):
3535
"""Benchmarks for projecting fields."""
3636

3737
# pylint: disable=g-complex-comprehension
@@ -389,7 +389,7 @@ def test_project(self, fn_name, fn_args, proto_list_key):
389389
self.run_benchmarks(fn_name, _get_project_fn, fn_args, proto_list_key)
390390

391391

392-
class PromoteBenchmarks(struct2tensor_benchmark_util.Struct2tensorBenchmarks):
392+
class PromoteBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks):
393393
"""Benchmarks for promoting fields."""
394394

395395
@parameterized.named_parameters(*[
@@ -431,7 +431,7 @@ def test_promote(self, fn_name, fn_args, proto_list_key):
431431
self.run_benchmarks(fn_name, _get_promote_fn, fn_args, proto_list_key)
432432

433433

434-
class BroadcastBenchmarks(struct2tensor_benchmark_util.Struct2tensorBenchmarks):
434+
class BroadcastBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks):
435435
"""Benchmarks for broadcasting fields."""
436436

437437
@parameterized.named_parameters(*[
@@ -473,7 +473,7 @@ def test_broadcast(self, fn_name, fn_args, proto_list_key):
473473
self.run_benchmarks(fn_name, _get_broadcast_fn, fn_args, proto_list_key)
474474

475475

476-
class RerootBenchmarks(struct2tensor_benchmark_util.Struct2tensorBenchmarks):
476+
class RerootBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks):
477477
"""Benchmarks for rerooting fields."""
478478

479479
@parameterized.named_parameters(*[
@@ -518,7 +518,7 @@ def test_reroot(self, fn_name, fn_args, proto_list_key):
518518

519519

520520
class PrensorToTensorBenchmarks(
521-
struct2tensor_benchmark_util.Struct2tensorBenchmarks):
521+
struct2tensor_benchmark_util.ProtoDataBenchmarks):
522522
"""Benchmarks for converting prensor to tensors."""
523523

524524
# pylint: disable=g-complex-comprehension
@@ -825,7 +825,7 @@ def test_to_sparse(self, fn_name, fn_args, proto_list_key):
825825
proto_list_key)
826826

827827

828-
class TfExampleBenchmarks(struct2tensor_benchmark_util.Struct2tensorBenchmarks):
828+
class TfExampleBenchmarks(struct2tensor_benchmark_util.ProtoDataBenchmarks):
829829
"""Benchmarks for converting tf.example to tensors."""
830830

831831
# pylint: disable=g-complex-comprehension

0 commit comments

Comments
 (0)