Skip to content

Commit c84d530

Browse files
authored
feat: merge postprocess and pool (#49)
1 parent 3b902c0 commit c84d530

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

encoderfile/src/inference/sentence_embedding.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub fn sentence_embedding<'a>(
3232

3333
let pooled_outputs = transform.pool(outputs, a_mask_arr)?;
3434

35-
let embeddings = postprocess(state.transform().postprocess(pooled_outputs)?, encodings);
35+
let embeddings = postprocess(pooled_outputs, encodings);
3636

3737
Ok(embeddings)
3838
}

encoderfile/src/transforms/engine.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ pub struct Transform {
99
#[allow(dead_code)]
1010
lua: Lua,
1111
postprocessor: Option<LuaFunction>,
12-
pooler: Option<LuaFunction>,
1312
}
1413

1514
impl Transform {
@@ -26,20 +25,14 @@ impl Transform {
2625
.get::<Option<LuaFunction>>("Postprocess")
2726
.map_err(|e| ApiError::LuaError(e.to_string()))?;
2827

29-
let pooler = lua
30-
.globals()
31-
.get::<Option<LuaFunction>>("Pool")
32-
.map_err(|e| ApiError::LuaError(e.to_string()))?;
33-
3428
Ok(Self {
3529
lua,
3630
postprocessor,
37-
pooler,
3831
})
3932
}
4033

4134
pub fn pool(&self, data: Array3<f32>, mask: Array2<f32>) -> Result<Array2<f32>, ApiError> {
42-
let func = match &self.pooler {
35+
let func = match &self.postprocessor {
4336
Some(p) => p,
4437
None => {
4538
let batch = data.len_of(Axis(0));
@@ -193,7 +186,7 @@ mod tests {
193186
fn test_successful_pool() {
194187
let engine = Transform::new(
195188
r##"
196-
function Pool(arr, mask)
189+
function Postprocess(arr, mask)
197190
-- sum along second axis (lol)
198191
return arr:sum_axis(2)
199192
end
@@ -213,7 +206,7 @@ mod tests {
213206
fn test_bad_dim_pool() {
214207
let engine = Transform::new(
215208
r##"
216-
function Pool(arr, mask)
209+
function Postprocess(arr, mask)
217210
return arr
218211
end
219212
"##,

0 commit comments

Comments
 (0)