@@ -61,6 +61,24 @@ enum EmbeddingsWrap {
6161 View ( Embeddings < VocabWrap , StorageViewWrap > ) ,
6262}
6363
64+ impl EmbeddingsWrap {
65+ fn storage ( & self ) -> & Storage {
66+ use EmbeddingsWrap :: * ;
67+ match self {
68+ NonView ( e) => e. storage ( ) ,
69+ View ( e) => e. storage ( ) ,
70+ }
71+ }
72+
73+ fn vocab ( & self ) -> & VocabWrap {
74+ use EmbeddingsWrap :: * ;
75+ match self {
76+ NonView ( e) => e. vocab ( ) ,
77+ View ( e) => e. vocab ( ) ,
78+ }
79+ }
80+ }
81+
6482/// finalfusion embeddings.
6583#[ pyclass( name=Embeddings ) ]
6684struct PyEmbeddings {
@@ -329,12 +347,9 @@ impl PyVocab {
329347 fn item_to_indices ( & self , key : String ) -> PyResult < PyObject > {
330348 let embeds = self . embeddings . borrow ( ) ;
331349
332- use EmbeddingsWrap :: * ;
333- let indices = match & * embeds {
334- View ( e) => e. vocab ( ) . idx ( key. as_str ( ) ) ,
335- NonView ( e) => e. vocab ( ) . idx ( key. as_str ( ) ) ,
336- } ;
337- indices
350+ embeds
351+ . vocab ( )
352+ . idx ( key. as_str ( ) )
338353 . map ( |idx| {
339354 let gil = pyo3:: Python :: acquire_gil ( ) ;
340355 match idx {
@@ -350,28 +365,17 @@ impl PyVocab {
350365impl PySequenceProtocol for PyVocab {
351366 fn __len__ ( & self ) -> PyResult < usize > {
352367 let embeds = self . embeddings . borrow ( ) ;
353-
354- use EmbeddingsWrap :: * ;
355- match & * embeds {
356- View ( e) => Ok ( e. vocab ( ) . words ( ) . len ( ) ) ,
357- NonView ( e) => Ok ( e. vocab ( ) . words ( ) . len ( ) ) ,
358- }
368+ Ok ( embeds. vocab ( ) . len ( ) )
359369 }
360370
361371 fn __getitem__ ( & self , idx : isize ) -> PyResult < String > {
362372 let embeds = self . embeddings . borrow ( ) ;
373+ let words = embeds. vocab ( ) . words ( ) ;
363374
364- use EmbeddingsWrap :: * ;
365-
366- let v = match & * embeds {
367- View ( e) => e. vocab ( ) ,
368- NonView ( e) => e. vocab ( ) ,
369- } ;
370-
371- if idx >= v. words ( ) . len ( ) as isize || idx < 0 {
375+ if idx >= words. len ( ) as isize || idx < 0 {
372376 Err ( exceptions:: IndexError :: py_err ( "list index out of range" ) )
373377 } else {
374- Ok ( v . words ( ) [ idx as usize ] . clone ( ) )
378+ Ok ( words[ idx as usize ] . clone ( ) )
375379 }
376380 }
377381}
@@ -408,20 +412,11 @@ impl PyIterProtocol for PyEmbeddingIterator {
408412 let slf = & mut * slf;
409413
410414 let embeddings = slf. embeddings . borrow ( ) ;
411-
412- use EmbeddingsWrap :: * ;
413- let vocab = match & * embeddings {
414- View ( e) => e. vocab ( ) ,
415- NonView ( e) => e. vocab ( ) ,
416- } ;
415+ let vocab = embeddings. vocab ( ) ;
417416
418417 if slf. idx < vocab. len ( ) {
419418 let word = vocab. words ( ) [ slf. idx ] . to_string ( ) ;
420-
421- let embed = match & * embeddings {
422- View ( e) => e. storage ( ) . embedding ( slf. idx ) ,
423- NonView ( e) => e. storage ( ) . embedding ( slf. idx ) ,
424- } ;
419+ let embed = embeddings. storage ( ) . embedding ( slf. idx ) ;
425420
426421 slf. idx += 1 ;
427422
0 commit comments