@@ -809,155 +809,183 @@ protected void doExecute(Task task,
809
809
final ClusteringActionRequest clusteringRequest ,
810
810
final ActionListener <ClusteringActionResponse > listener ) {
811
811
final long tsSearchStart = System .nanoTime ();
812
- searchAction .execute (clusteringRequest .getSearchRequest (), new ActionListener <SearchResponse >() {
812
+ searchAction .execute (
813
+ clusteringRequest .getSearchRequest (),
814
+ new ActionListener <SearchResponse >() {
813
815
@ Override
814
816
public void onFailure (Exception e ) {
815
- listener .onFailure (e );
817
+ listener .onFailure (e );
816
818
}
817
819
818
820
@ Override
819
821
public void onResponse (SearchResponse response ) {
820
- final long tsSearchEnd = System .nanoTime ();
821
-
822
- LinkedHashMap <String , ClusteringAlgorithmProvider > algorithms = context .getAlgorithms ();
823
-
824
- final String algorithmId = requireNonNullElse (
825
- clusteringRequest .getAlgorithm (), algorithms .keySet ().iterator ().next ());
826
-
827
- ClusteringAlgorithmProvider provider = algorithms .get (algorithmId );
828
- if (provider == null ) {
829
- listener .onFailure (new IllegalArgumentException ("No such algorithm: " + algorithmId ));
830
- return ;
831
- }
832
-
833
- /*
834
- * We're not a threaded listener so we're running on the search thread. This
835
- * is good -- we don't want to serve more clustering requests than we can handle
836
- * anyway.
837
- */
838
- ClusteringAlgorithm algorithm = provider .get ();
839
-
840
- try {
841
- Map <String , Object > requestAttrs = clusteringRequest .getAttributes ();
842
- if (requestAttrs != null ) {
843
- Attrs .populate (algorithm , requestAttrs );
844
- }
845
-
846
- String queryHint = clusteringRequest .getQueryHint ();
847
- if (queryHint != null ) {
848
- algorithm .accept (new OptionalQueryHintSetterVisitor (clusteringRequest .getQueryHint ()));
822
+ final long tsSearchEnd = System .nanoTime ();
823
+
824
+ LinkedHashMap <String , ClusteringAlgorithmProvider > algorithms =
825
+ context .getAlgorithms ();
826
+
827
+ final String algorithmId =
828
+ requireNonNullElse (
829
+ clusteringRequest .getAlgorithm (), algorithms .keySet ().iterator ().next ());
830
+
831
+ ClusteringAlgorithmProvider provider = algorithms .get (algorithmId );
832
+ if (provider == null ) {
833
+ listener .onFailure (
834
+ new IllegalArgumentException ("No such algorithm: " + algorithmId ));
835
+ return ;
836
+ }
837
+
838
+ /*
839
+ * We're not a threaded listener so we're running on the search thread. This
840
+ * is good -- we don't want to serve more clustering requests than we can handle
841
+ * anyway.
842
+ */
843
+ ClusteringAlgorithm algorithm = provider .get ();
844
+
845
+ try {
846
+ Map <String , Object > requestAttrs = clusteringRequest .getAttributes ();
847
+ if (requestAttrs != null ) {
848
+ Attrs .populate (algorithm , requestAttrs );
849
+ }
850
+
851
+ String queryHint = clusteringRequest .getQueryHint ();
852
+ if (queryHint != null ) {
853
+ algorithm .accept (
854
+ new OptionalQueryHintSetterVisitor (clusteringRequest .getQueryHint ()));
855
+ }
856
+
857
+ List <InputDocument > documents =
858
+ prepareDocumentsForClustering (clusteringRequest , response );
859
+
860
+ String defaultLanguage = clusteringRequest .getDefaultLanguage ();
861
+ if (!context .isLanguageSupported (defaultLanguage )) {
862
+ throw new RuntimeException (
863
+ "The requested default language is not supported: '" + defaultLanguage + "'" );
864
+ }
865
+
866
+ // Split documents into language groups.
867
+ Map <String , List <InputDocument >> documentsByLanguage =
868
+ documents .stream ()
869
+ .collect (
870
+ Collectors .groupingBy (
871
+ doc -> {
872
+ String lang = doc .language ();
873
+ return lang == null ? defaultLanguage : lang ;
874
+ }));
875
+
876
+ // Run clustering.
877
+ long tsClusteringTotal = 0 ;
878
+ HashSet <String > warnOnce = new HashSet <>();
879
+ LinkedHashMap <String , List <Cluster <InputDocument >>> clustersByLanguage =
880
+ new LinkedHashMap <>();
881
+ for (Map .Entry <String , List <InputDocument >> e : documentsByLanguage .entrySet ()) {
882
+ String lang = e .getKey ();
883
+ if (!context .isLanguageSupported (lang )) {
884
+ if (warnOnce .add (lang )) {
885
+ logger .warn (
886
+ "Language is not supported, documents in this "
887
+ + "language will not be clustered: '"
888
+ + lang
889
+ + "'" );
890
+ }
891
+ } else {
892
+ LanguageComponents languageComponents = context .getLanguageComponents (lang );
893
+ final long tsClusteringStart = System .nanoTime ();
894
+ clustersByLanguage .put (
895
+ lang , algorithm .cluster (e .getValue ().stream (), languageComponents ));
896
+ final long tsClusteringEnd = System .nanoTime ();
897
+ tsClusteringTotal += (tsClusteringEnd - tsClusteringStart );
849
898
}
850
-
851
- List <InputDocument > documents = prepareDocumentsForClustering (clusteringRequest , response );
852
-
853
- String defaultLanguage = clusteringRequest .getDefaultLanguage ();
854
- if (!context .isLanguageSupported (defaultLanguage )) {
855
- throw new RuntimeException ("The requested default language is not supported: '" + defaultLanguage + "'" );
856
- }
857
-
858
- // Split documents into language groups.
859
- Map <String , List <InputDocument >> documentsByLanguage = documents .stream ()
860
- .collect (Collectors .groupingBy (
861
- doc -> {
862
- String lang = doc .language ();
863
- return lang == null ? defaultLanguage : lang ;
864
- }
865
- ));
866
-
867
- // Run clustering.
868
- long tsClusteringTotal = 0 ;
869
- HashSet <String > warnOnce = new HashSet <>();
870
- LinkedHashMap <String , List <Cluster <InputDocument >>> clustersByLanguage = new LinkedHashMap <>();
871
- for (Map .Entry <String , List <InputDocument >> e : documentsByLanguage .entrySet ()) {
872
- String lang = e .getKey ();
873
- if (!context .isLanguageSupported (lang )) {
874
- if (warnOnce .add (lang )) {
875
- logger .warn ("Language is not supported, documents in this " +
876
- "language will not be clustered: '" + lang + "'" );
877
- }
878
- } else {
879
- LanguageComponents languageComponents = context .getLanguageComponents (lang );
880
- final long tsClusteringStart = System .nanoTime ();
881
- clustersByLanguage .put (lang , algorithm .cluster (e .getValue ().stream (), languageComponents ));
882
- final long tsClusteringEnd = System .nanoTime ();
883
- tsClusteringTotal += (tsClusteringEnd - tsClusteringStart );
884
- }
885
- }
886
-
887
- final Map <String , String > info = new LinkedHashMap <>();
888
- info .put (ClusteringActionResponse .Fields .Info .ALGORITHM ,
889
- algorithmId );
890
- info .put (ClusteringActionResponse .Fields .Info .SEARCH_MILLIS ,
891
- Long .toString (TimeUnit .NANOSECONDS .toMillis (tsSearchEnd - tsSearchStart )));
892
- info .put (ClusteringActionResponse .Fields .Info .CLUSTERING_MILLIS ,
893
- Long .toString (TimeUnit .NANOSECONDS .toMillis (tsClusteringTotal )));
894
- info .put (ClusteringActionResponse .Fields .Info .TOTAL_MILLIS ,
895
- Long .toString (TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - tsSearchStart )));
896
- info .put (ClusteringActionResponse .Fields .Info .INCLUDE_HITS ,
897
- Boolean .toString (clusteringRequest .getIncludeHits ()));
898
- info .put (ClusteringActionResponse .Fields .Info .MAX_HITS ,
899
- clusteringRequest .getMaxHits () == Integer .MAX_VALUE ?
900
- "" : Integer .toString (clusteringRequest .getMaxHits ()));
901
- info .put (ClusteringActionResponse .Fields .Info .LANGUAGES ,
902
- String .join (", " , clustersByLanguage .keySet ()));
903
-
904
- // Trim search response's hits if we need to.
905
- if (clusteringRequest .getMaxHits () != Integer .MAX_VALUE ) {
906
- response = filterMaxHits (response , clusteringRequest .getMaxHits ());
907
- }
908
-
909
- AtomicInteger groupId = new AtomicInteger ();
910
- Map <String , DocumentGroup []> adaptedByLanguage = clustersByLanguage .entrySet ()
911
- .stream ()
912
- .filter (e -> !e .getValue ().isEmpty ())
913
- .collect (Collectors .toMap (
914
- Map .Entry ::getKey ,
915
- e -> adapt (e .getValue (), groupId )
916
- ));
917
-
918
- final ArrayList <DocumentGroup > groups = new ArrayList <>();
919
- adaptedByLanguage .values ()
920
- .forEach (langClusters -> groups .addAll (Arrays .asList (langClusters )));
921
-
922
- if (adaptedByLanguage .size () > 1 ) {
923
- groups .sort ((a , b ) -> Integer .compare (b .uniqueDocuments ().size (), a .uniqueDocuments ().size ()));
924
- }
925
-
926
- if (clusteringRequest .createUngroupedDocumentsCluster ) {
927
- DocumentGroup ungrouped = new DocumentGroup ();
928
- ungrouped .setId (groupId .incrementAndGet ());
929
- ungrouped .setPhrases (new String []{"Ungrouped documents" });
930
- ungrouped .setUngroupedDocuments (true );
931
- ungrouped .setScore (0d );
932
-
933
- LinkedHashSet <InputDocument > ungroupedDocuments = new LinkedHashSet <>(documents );
934
- clustersByLanguage .values ().forEach (
935
- langClusters -> removeReferenced (ungroupedDocuments , langClusters ));
936
- ungrouped .setDocumentReferences (
937
- ungroupedDocuments .stream ().map (InputDocument ::getStringId ).toArray (String []::new ));
938
-
939
- groups .add (ungrouped );
940
- }
941
-
942
- listener .onResponse (
943
- new ClusteringActionResponse (response , groups .toArray (new DocumentGroup [0 ]), info ));
944
- } catch (Exception e ) {
945
- // Log a full stack trace with all nested exceptions but only return
946
- // ElasticSearchException exception with a simple String (otherwise
947
- // clients cannot deserialize exception classes).
948
- String message = "Clustering error: " + e .getMessage ();
949
- logger .warn (message , e );
950
- listener .onFailure (new ElasticsearchException (message ));
951
- }
899
+ }
900
+
901
+ final Map <String , String > info = new LinkedHashMap <>();
902
+ info .put (ClusteringActionResponse .Fields .Info .ALGORITHM , algorithmId );
903
+ info .put (
904
+ ClusteringActionResponse .Fields .Info .SEARCH_MILLIS ,
905
+ Long .toString (TimeUnit .NANOSECONDS .toMillis (tsSearchEnd - tsSearchStart )));
906
+ info .put (
907
+ ClusteringActionResponse .Fields .Info .CLUSTERING_MILLIS ,
908
+ Long .toString (TimeUnit .NANOSECONDS .toMillis (tsClusteringTotal )));
909
+ info .put (
910
+ ClusteringActionResponse .Fields .Info .TOTAL_MILLIS ,
911
+ Long .toString (
912
+ TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - tsSearchStart )));
913
+ info .put (
914
+ ClusteringActionResponse .Fields .Info .INCLUDE_HITS ,
915
+ Boolean .toString (clusteringRequest .getIncludeHits ()));
916
+ info .put (
917
+ ClusteringActionResponse .Fields .Info .MAX_HITS ,
918
+ clusteringRequest .getMaxHits () == Integer .MAX_VALUE
919
+ ? ""
920
+ : Integer .toString (clusteringRequest .getMaxHits ()));
921
+ info .put (
922
+ ClusteringActionResponse .Fields .Info .LANGUAGES ,
923
+ String .join (", " , clustersByLanguage .keySet ()));
924
+
925
+ // Trim search response's hits if we need to.
926
+ if (clusteringRequest .getMaxHits () != Integer .MAX_VALUE ) {
927
+ response = filterMaxHits (response , clusteringRequest .getMaxHits ());
928
+ }
929
+
930
+ AtomicInteger groupId = new AtomicInteger ();
931
+ Map <String , DocumentGroup []> adaptedByLanguage =
932
+ clustersByLanguage .entrySet ().stream ()
933
+ .filter (e -> !e .getValue ().isEmpty ())
934
+ .collect (
935
+ Collectors .toMap (Map .Entry ::getKey , e -> adapt (e .getValue (), groupId )));
936
+
937
+ final ArrayList <DocumentGroup > groups = new ArrayList <>();
938
+ adaptedByLanguage
939
+ .values ()
940
+ .forEach (langClusters -> groups .addAll (Arrays .asList (langClusters )));
941
+
942
+ if (adaptedByLanguage .size () > 1 ) {
943
+ groups .sort (
944
+ (a , b ) ->
945
+ Integer .compare (b .uniqueDocuments ().size (), a .uniqueDocuments ().size ()));
946
+ }
947
+
948
+ if (clusteringRequest .createUngroupedDocumentsCluster ) {
949
+ DocumentGroup ungrouped = new DocumentGroup ();
950
+ ungrouped .setId (groupId .incrementAndGet ());
951
+ ungrouped .setPhrases (new String [] {"Ungrouped documents" });
952
+ ungrouped .setUngroupedDocuments (true );
953
+ ungrouped .setScore (0d );
954
+
955
+ LinkedHashSet <InputDocument > ungroupedDocuments = new LinkedHashSet <>(documents );
956
+ clustersByLanguage
957
+ .values ()
958
+ .forEach (langClusters -> removeReferenced (ungroupedDocuments , langClusters ));
959
+ ungrouped .setDocumentReferences (
960
+ ungroupedDocuments .stream ()
961
+ .map (InputDocument ::getStringId )
962
+ .toArray (String []::new ));
963
+
964
+ groups .add (ungrouped );
965
+ }
966
+
967
+ listener .onResponse (
968
+ new ClusteringActionResponse (
969
+ response , groups .toArray (new DocumentGroup [0 ]), info ));
970
+ } catch (Exception e ) {
971
+ // Log a full stack trace with all nested exceptions but only return
972
+ // ElasticSearchException exception with a simple String (otherwise
973
+ // clients cannot deserialize exception classes).
974
+ String message = "Clustering error: " + e .getMessage ();
975
+ logger .warn (message , e );
976
+ listener .onFailure (new ElasticsearchException (message ));
977
+ }
952
978
}
953
979
954
- private void removeReferenced (LinkedHashSet <InputDocument > ungrouped , List <Cluster <InputDocument >> clusters ) {
955
- clusters .forEach (cluster -> {
956
- ungrouped .removeAll (cluster .getDocuments ());
957
- removeReferenced (ungrouped , cluster .getClusters ());
958
- });
980
+ private void removeReferenced (
981
+ LinkedHashSet <InputDocument > ungrouped , List <Cluster <InputDocument >> clusters ) {
982
+ clusters .forEach (
983
+ cluster -> {
984
+ ungrouped .removeAll (cluster .getDocuments ());
985
+ removeReferenced (ungrouped , cluster .getClusters ());
986
+ });
959
987
}
960
- });
988
+ });
961
989
}
962
990
963
991
public static <T > T requireNonNullElse (T first , T def ) {
0 commit comments