Skip to content

Conversation

@jan-elastic
Copy link
Contributor

@jan-elastic jan-elastic commented Feb 18, 2025

When preparing for/trying to implement a fix for OOM when performing inference on an extremely large document with a semantic_text field, I was annoyed by the amount of code duplication in EmbeddingRequestChunker (e.g. the switch/case by embedding type, different result fields per type, and different update/merge methods, ...). I've added generics to avoid all this duplication.

This saves almost 400 lines of code, while making the structure more clear (at least to me).

Furthermore, the naming of classes was pretty inconsistent/confusing, and there were some unclear interfaces. I've unified most of that. Here's the new naming scheme:

  • EmbeddingResults (new interface) raw results of the inference service with generics. Also contains EmbeddingResults.Embedding and EmbeddingResults.Chunk.
  • SparseEmbeddingResults (similar to before, but with ChunkedInferenceEmbeddingSparse merged into it; took this structure as inspiration for the other classes) implementation of EmbeddingResults for sparse embeddings
  • TextEmbeddingByteResults (previously InferenceTextEmbeddingByteResults, InferenceByteEmbedding and ChunkedInferenceEmbeddingByte) implementation of EmbeddingResults for dense byte embeddings
  • TextEmbeddingFloatResults (previously InferenceTextEmbeddingFloatResults and ChunkedInferenceEmbeddingFloat) implementation of EmbeddingResults for dense float embeddings
  • TextEmbeddingBitResults (previously InferenceTextEmbeddingBitResults) implementation of EmbeddingResults for dense bit embeddings. Note that chunking is now supported for this as well without any added code, even though it wasn't supported previously.
  • ChunkedInferenceEmbedding final result class that implements ChunkedInference. This is eventually produced by the EmbeddingRequestChunker

In the process, the EmbeddingInt interface (which just contains a getSize method) got removed too.

Apologies for the high number of tests that got updated, but that's a consequence of the ~15-fold code duplication for all inference providers.

Regarding the relations of the classes/interfaces:

  • class TextEmbedding{Bit,Byte,Float}Results is a interface TextEmbeddingResults
  • class SparseEmbeddingResults and interface TextEmbeddingResults is a interface EmbeddingResults
  • interface EmbeddingResults is a interface InferenceServiceResults

Furthermore:

  • class TextEmbedding{Bit,Byte,Float}Results and class SparseEmbeddingResults contain a lists of Embeddings
  • these Embeddings can be transformed to Chunks
  • these Chunks can be packed in a class ChunkedInferenceEmbedding, which is a interface ChunkedInference

@jan-elastic jan-elastic marked this pull request as draft February 18, 2025 10:45
@jan-elastic jan-elastic force-pushed the refactor-EmbeddingRequestChunker branch from 73fe062 to 0ada0a9 Compare February 18, 2025 12:02
assertThat(batches, hasSize(1));
assertEquals(batches.get(0).batch().inputs(), inputs);
var subBatches = batches.get(0).batch().subBatches();
EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the only thing that has changed to the tests is asserting the requests, which are internal and not used outside of this class. (Which makes me wonder if this is a great test.)

@jan-elastic jan-elastic force-pushed the refactor-EmbeddingRequestChunker branch from 44cdcc4 to 93489dd Compare February 19, 2025 09:50
@github-actions
Copy link
Contributor

Warning

It looks like this PR modifies one or more .asciidoc files. These files are being migrated to Markdown, and any changes merged now will be lost. See the migration guide for details.

@jan-elastic jan-elastic force-pushed the refactor-EmbeddingRequestChunker branch 2 times, most recently from 49fc34a to 7726027 Compare February 19, 2025 10:25
@jan-elastic jan-elastic added :ml Machine learning Team:ML Meta label for the ML team labels Feb 19, 2025
@jan-elastic jan-elastic marked this pull request as ready for review February 19, 2025 13:20
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

Copy link
Contributor

@jonathan-buttner jonathan-buttner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renames and switch case clean up look good! The addToBatches changes I'm not sure about, I'm not as familiar with this part of the code but I'll take another look.

switch (embeddingType) {
case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to confirm but it seems like AtomicArray permits null values. I can't tell if the previous code was trying to retrieve them though 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AtomicArray (old code) is just a wrapper of AtomicReferenceArray (new code), of which the additional logic wasn't really used (except for setOnce, which is just set with an extra validation).

}
}

private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did all of this logic distill to the stream group-by call below? Or is this the primary changes to begin addressing the OOMs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this basically does the same with less code. This PR functionally doesn't change a thing.

It's just cleaning up, so that I can address the OOM issue in a clean kitchen (which I like).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did all of this logic distill to the stream group-by call below?

All this logic distilled to the stream group-by 🫢

Copy link
Member

@davidkyle davidkyle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Thanks for the clean up

AtomicInteger counter = new AtomicInteger();
this.batchRequests = requests.stream()
.flatMap(List::stream)
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, you've deleted hundreds of lines of code with this 😁

@jan-elastic jan-elastic force-pushed the refactor-EmbeddingRequestChunker branch from 7fdad46 to 099c35a Compare February 21, 2025 12:32
@jan-elastic jan-elastic merged commit 5f99708 into main Feb 21, 2025
18 checks passed
@jan-elastic jan-elastic deleted the refactor-EmbeddingRequestChunker branch February 21, 2025 14:29
jan-elastic added a commit that referenced this pull request Feb 25, 2025
* refactor

* inference generics

* more refactor

* unify naming

* remove interface "EmbeddingInt"

* more renaming

* javadoc

* revert accidental changeas

* remove ununsed EmbeddingRequestChunker.EmbeddingType

* polish

* support chunking for text embedding bits

* Polish error messagex

* fix VoyageAI conflicts
elasticsearchmachine pushed a commit that referenced this pull request Feb 25, 2025
* Refactor EmbeddingRequestChunker (#122818)

* refactor

* inference generics

* more refactor

* unify naming

* remove interface "EmbeddingInt"

* more renaming

* javadoc

* revert accidental changeas

* remove ununsed EmbeddingRequestChunker.EmbeddingType

* polish

* support chunking for text embedding bits

* Polish error messagex

* fix VoyageAI conflicts

* conflicts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

:ml Machine learning >non-issue Team:ML Meta label for the ML team v9.1.0

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants