Refactor EmbeddingRequestChunker #122818

Conversation
```java
assertThat(batches, hasSize(1));
assertEquals(batches.get(0).batch().inputs(), inputs);
var subBatches = batches.get(0).batch().subBatches();
EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch();
```
Note: the only thing that has changed in the tests is asserting the requests, which are internal and not used outside of this class. (Which makes me wonder whether this is a great test.)
Pinging @elastic/ml-core (Team:ML)
jonathan-buttner
left a comment
Renames and switch-case cleanup look good! I'm not sure about the addToBatches changes; I'm not as familiar with this part of the code, but I'll take another look.
```java
switch (embeddingType) {
    case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
    case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
    case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
```
Just wanted to confirm, but it seems like AtomicArray permits null values. I can't tell whether the previous code was trying to retrieve them though 🤔
The AtomicArray (old code) is just a wrapper around AtomicReferenceArray (new code), whose additional logic wasn't really used (except for setOnce, which is just a set with an extra validation).
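For context, a minimal sketch of that setOnce-with-validation semantics on top of AtomicReferenceArray (class and method names here are hypothetical; this is not the actual Elasticsearch AtomicArray):

```java
import java.util.concurrent.atomic.AtomicReferenceArray;

// Sketch of a setOnce-style wrapper around AtomicReferenceArray: a plain set,
// plus a validation that the slot hasn't been written before.
final class SetOnceArray<E> {
    private final AtomicReferenceArray<E> array;

    SetOnceArray(int size) {
        this.array = new AtomicReferenceArray<>(size);
    }

    // Like set(i, value), but fails if the slot was already set.
    void setOnce(int i, E value) {
        if (array.compareAndSet(i, null, value) == false) {
            throw new IllegalStateException("index [" + i + "] was already set");
        }
    }

    // May return null for slots that were never set.
    E get(int i) {
        return array.get(i);
    }
}
```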
```java
    }
}

private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex)
```
Did all of this logic distill down to the stream group-by call below? Or is this the primary change to begin addressing the OOMs?
Yes, this basically does the same thing with less code. Functionally, this PR doesn't change a thing.
It's just cleanup, so that I can address the OOM issue in a clean kitchen (which I like).
> Did all of this logic distill to the stream group-by call below?
All this logic distilled to the stream group-by 🫢
davidkyle
left a comment
LGTM
Thanks for the clean up
```java
AtomicInteger counter = new AtomicInteger();
this.batchRequests = requests.stream()
    .flatMap(List::stream)
    .collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
```
Nice, you've deleted hundreds of lines of code with this 😁
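For readers skimming the diff, here's a self-contained sketch of that grouping idiom (the method and parameter names are placeholders, not the PR's actual code):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

final class BatchingSketch {
    // Partition a flat list into batches of at most maxPerBatch elements by
    // grouping each element on a running counter. The stateful classifier is
    // only safe on a sequential stream.
    static <T> List<List<T>> toBatches(List<T> inputs, int maxPerBatch) {
        AtomicInteger counter = new AtomicInteger();
        var grouped = inputs.stream()
            .collect(Collectors.groupingBy(
                it -> counter.getAndIncrement() / maxPerBatch,
                TreeMap::new, // keep batches in input order
                Collectors.toList()
            ));
        return new ArrayList<>(grouped.values());
    }

    public static void main(String[] args) {
        // Prints [[a, b], [c, d], [e]]
        System.out.println(toBatches(List.of("a", "b", "c", "d", "e"), 2));
    }
}
```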
Squashed commits:
* refactor
* inference generics
* more refactor
* unify naming
* remove interface "EmbeddingInt"
* more renaming
* javadoc
* revert accidental changes
* remove unused EmbeddingRequestChunker.EmbeddingType
* polish
* support chunking for text embedding bits
* polish error messages
* fix VoyageAI conflicts
When preparing to implement a fix for OOMs when performing inference on an extremely large document with a semantic_text field, I was annoyed by the amount of code duplication in `EmbeddingRequestChunker` (e.g. the `switch`/`case` by embedding type, different result fields per type, different update/merge methods, ...). I've added generics to avoid all this duplication. This saves almost 400 lines of code, while making the structure clearer (at least to me).
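To illustrate the shape of the change (a hedged sketch, not the PR's actual code): instead of one result field per embedding type plus a switch over the type, a single accumulator parameterized by the embedding type suffices:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReferenceArray;

// Before (roughly): floatResults, byteResults, sparseResults fields, each fed
// by its own arm of a switch (embeddingType) { ... }.
// After (sketch): one generic accumulator covers all embedding types.
final class ResultsAccumulator<E> {
    private final List<AtomicReferenceArray<E>> results = new ArrayList<>();

    void addBatch(int numberOfSubBatches) {
        results.add(new AtomicReferenceArray<>(numberOfSubBatches));
    }

    void set(int batch, int subBatch, E embedding) {
        results.get(batch).set(subBatch, embedding);
    }
}
```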
Furthermore, the naming of classes was pretty inconsistent/confusing, and there were some unclear interfaces. I've unified most of that. Here's the new naming scheme:
- `EmbeddingResults` (new interface): raw results of the inference service, with generics. Also contains `EmbeddingResults.Embedding` and `EmbeddingResults.Chunk`.
- `SparseEmbeddingResults` (similar to before, but with `ChunkedInferenceEmbeddingSparse` merged into it; took this structure as inspiration for the other classes): implementation of `EmbeddingResults` for sparse embeddings.
- `TextEmbeddingByteResults` (previously `InferenceTextEmbeddingByteResults`, `InferenceByteEmbedding`, and `ChunkedInferenceEmbeddingByte`): implementation of `EmbeddingResults` for dense byte embeddings.
- `TextEmbeddingFloatResults` (previously `InferenceTextEmbeddingFloatResults` and `ChunkedInferenceEmbeddingFloat`): implementation of `EmbeddingResults` for dense float embeddings.
- `TextEmbeddingBitResults` (previously `InferenceTextEmbeddingBitResults`): implementation of `EmbeddingResults` for dense bit embeddings. Note that chunking is now supported for this as well, without any added code, even though it wasn't supported previously.
- `ChunkedInferenceEmbedding`: final result class that implements `ChunkedInference`. This is eventually produced by the `EmbeddingRequestChunker`.

In the process, the `EmbeddingInt` interface (which just contains a `getSize` method) got removed too. Apologies for the high number of tests that got updated, but that's a consequence of the ~15-fold code duplication across all inference providers.
Regarding the relations of the classes/interfaces:
- `class TextEmbedding{Bit,Byte,Float}Results` is a `interface TextEmbeddingResults`
- `class SparseEmbeddingResults` and `interface TextEmbeddingResults` are each an `interface EmbeddingResults`
- `interface EmbeddingResults` is a `interface InferenceServiceResults`

Furthermore:
- `class TextEmbedding{Bit,Byte,Float}Results` and `class SparseEmbeddingResults` contain lists of `Embedding`s
- `Embedding`s can be transformed to `Chunk`s
- `Chunk`s can be packed into a `class ChunkedInferenceEmbedding`, which is a `interface ChunkedInference`
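The same relations as a compilable sketch (members and generics elided; a simplified illustration, not the actual source):

```java
// Top of the hierarchy: every results type is an InferenceServiceResults.
interface InferenceServiceResults {}
interface ChunkedInference {}

interface EmbeddingResults<E> extends InferenceServiceResults {}
interface TextEmbeddingResults<E> extends EmbeddingResults<E> {}

final class SparseEmbeddingResults implements EmbeddingResults<Void> {}
final class TextEmbeddingBitResults implements TextEmbeddingResults<Void> {}
final class TextEmbeddingByteResults implements TextEmbeddingResults<Void> {}
final class TextEmbeddingFloatResults implements TextEmbeddingResults<Void> {}

// Embeddings -> Chunks -> ChunkedInferenceEmbedding, the final ChunkedInference.
final class ChunkedInferenceEmbedding implements ChunkedInference {}
```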