Skip to content

Commit ab0aea8

Browse files
sdks/python: reuse chunk_approximately_equals function across RAG
1 parent 466d533 commit ab0aea8

File tree

3 files changed

+3
-39
lines changed

3 files changed

+3
-39
lines changed

sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
from apache_beam.testing.test_pipeline import TestPipeline
3232
from apache_beam.testing.util import assert_that
3333
from apache_beam.testing.util import equal_to
34+
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
3435

3536
# pylint: disable=unused-import
3637
try:
@@ -40,19 +41,6 @@
4041
SENTENCE_TRANSFORMERS_AVAILABLE = False
4142

4243

43-
def chunk_approximately_equals(expected, actual):
44-
"""Compare embeddings allowing for numerical differences."""
45-
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
46-
return False
47-
48-
return (
49-
expected.id == actual.id and expected.metadata == actual.metadata and
50-
expected.content == actual.content and
51-
len(expected.embedding.dense_embedding) == len(
52-
actual.embedding.dense_embedding) and
53-
all(isinstance(x, float) for x in actual.embedding.dense_embedding))
54-
55-
5644
@pytest.mark.uses_transformers
5745
@unittest.skipIf(
5846
not SENTENCE_TRANSFORMERS_AVAILABLE, "sentence-transformers not available")

sdks/python/apache_beam/ml/rag/embeddings/open_ai_test.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
1716
import os
1817
import shutil
1918
import tempfile
@@ -28,18 +27,7 @@
2827
from apache_beam.testing.util import assert_that
2928
from apache_beam.testing.util import equal_to
3029
from apache_beam.ml.rag.embeddings.open_ai import OpenAITextEmbeddings
31-
32-
def chunk_approximately_equals(expected, actual):
33-
"""Compare embeddings allowing for numerical differences."""
34-
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
35-
return False
36-
37-
return (
38-
expected.id == actual.id and expected.metadata == actual.metadata and
39-
expected.content == actual.content and
40-
len(expected.embedding.dense_embedding) == len(
41-
actual.embedding.dense_embedding) and
42-
all(isinstance(x, float) for x in actual.embedding.dense_embedding))
30+
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
4331

4432

4533
class OpenAITextEmbeddingsTest(unittest.TestCase):

sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from apache_beam.testing.test_pipeline import TestPipeline
2929
from apache_beam.testing.util import assert_that
3030
from apache_beam.testing.util import equal_to
31+
from apache_beam.ml.rag.embeddings.test_utils import chunk_approximately_equals
3132

3233
# pylint: disable=ungrouped-imports
3334
try:
@@ -38,19 +39,6 @@
3839
VERTEX_AI_AVAILABLE = False
3940

4041

41-
def chunk_approximately_equals(expected, actual):
42-
"""Compare embeddings allowing for numerical differences."""
43-
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
44-
return False
45-
46-
return (
47-
expected.id == actual.id and expected.metadata == actual.metadata and
48-
expected.content == actual.content and
49-
len(expected.embedding.dense_embedding) == len(
50-
actual.embedding.dense_embedding) and
51-
all(isinstance(x, float) for x in actual.embedding.dense_embedding))
52-
53-
5442
@unittest.skipIf(
5543
not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
5644
class VertexAITextEmbeddingsTest(unittest.TestCase):

0 commit comments

Comments
 (0)