@@ -693,11 +693,13 @@ enum SamplerCommand {
693
693
Continue ,
694
694
Progress ,
695
695
Flush ,
696
+ Inspect ,
696
697
}
697
698
698
- enum SamplerResponse {
699
+ enum SamplerResponse < T : Send + ' static > {
699
700
Ok ( ) ,
700
701
Progress ( Box < [ ChainProgress ] > ) ,
702
+ Inspect ( T ) ,
701
703
}
702
704
703
705
pub enum SamplerWaitResult < F : Send + ' static > {
@@ -709,7 +711,7 @@ pub enum SamplerWaitResult<F: Send + 'static> {
709
711
pub struct Sampler < F : Send + ' static > {
710
712
main_thread : JoinHandle < Result < ( Option < anyhow:: Error > , F ) > > ,
711
713
commands : SyncSender < SamplerCommand > ,
712
- responses : Receiver < SamplerResponse > ,
714
+ responses : Receiver < SamplerResponse < ( Option < anyhow :: Error > , F ) > > ,
713
715
results : Receiver < Result < ( ) > > ,
714
716
}
715
717
@@ -827,7 +829,11 @@ impl<F: Send + 'static> Sampler<F> {
827
829
pause_start = Instant :: now ( ) ;
828
830
}
829
831
is_paused = true ;
830
- responses_tx. send ( SamplerResponse :: Ok ( ) ) ?;
832
+ responses_tx. send ( SamplerResponse :: Ok ( ) ) . map_err ( |e| {
833
+ anyhow:: anyhow!(
834
+ "Could not send pause response to controller thread: {e}"
835
+ )
836
+ } ) ?;
831
837
}
832
838
Ok ( SamplerCommand :: Continue ) => {
833
839
for chain in chains. iter ( ) {
@@ -837,18 +843,50 @@ impl<F: Send + 'static> Sampler<F> {
837
843
}
838
844
pause_time += pause_start. elapsed ( ) ;
839
845
is_paused = false ;
840
- responses_tx. send ( SamplerResponse :: Ok ( ) ) ?;
846
+ responses_tx. send ( SamplerResponse :: Ok ( ) ) . map_err ( |e| {
847
+ anyhow:: anyhow!(
848
+ "Could not send continue response to controller thread: {e}"
849
+ )
850
+ } ) ?;
841
851
}
842
852
Ok ( SamplerCommand :: Progress ) => {
843
853
let progress =
844
854
chains. iter ( ) . map ( |chain| chain. progress ( ) ) . collect_vec ( ) ;
845
- responses_tx. send ( SamplerResponse :: Progress ( progress. into ( ) ) ) ?;
855
+ responses_tx. send ( SamplerResponse :: Progress ( progress. into ( ) ) ) . map_err ( |e| {
856
+ anyhow:: anyhow!(
857
+ "Could not send progress response to controller thread: {e}"
858
+ )
859
+ } ) ?;
860
+ }
861
+ Ok ( SamplerCommand :: Inspect ) => {
862
+ let traces = chains
863
+ . iter ( )
864
+ . map ( |chain| {
865
+ chain
866
+ . trace
867
+ . lock ( )
868
+ . expect ( "Poisoned lock" )
869
+ . as_ref ( )
870
+ . map ( |v| v. inspect ( ) )
871
+ } )
872
+ . flatten ( )
873
+ . collect_vec ( ) ;
874
+ let finalized_trace = trace. inspect ( traces) ?;
875
+ responses_tx. send ( SamplerResponse :: Inspect ( finalized_trace) ) . map_err ( |e| {
876
+ anyhow:: anyhow!(
877
+ "Could not send inspect response to controller thread: {e}"
878
+ )
879
+ } ) ?;
846
880
}
847
881
Ok ( SamplerCommand :: Flush ) => {
848
882
for chain in chains. iter ( ) {
849
883
chain. flush ( ) ?;
850
884
}
851
- responses_tx. send ( SamplerResponse :: Ok ( ) ) ?;
885
+ responses_tx. send ( SamplerResponse :: Ok ( ) ) . map_err ( |e| {
886
+ anyhow:: anyhow!(
887
+ "Could not send flush response to controller thread: {e}"
888
+ )
889
+ } ) ?;
852
890
}
853
891
Err ( RecvTimeoutError :: Timeout ) => { }
854
892
Err ( RecvTimeoutError :: Disconnected ) => {
@@ -919,6 +957,18 @@ impl<F: Send + 'static> Sampler<F> {
919
957
Ok ( ( ) )
920
958
}
921
959
960
+ pub fn inspect ( & mut self ) -> Result < ( Option < anyhow:: Error > , F ) > {
961
+ self . commands . send ( SamplerCommand :: Inspect ) ?;
962
+ let response = self
963
+ . responses
964
+ . recv ( )
965
+ . context ( "Could not recieve inspect response from controller thread" ) ?;
966
+ let SamplerResponse :: Inspect ( trace) = response else {
967
+ bail ! ( "Got invalid response from sample controller thread" ) ;
968
+ } ;
969
+ Ok ( trace)
970
+ }
971
+
922
972
pub fn abort ( self ) -> Result < ( Option < anyhow:: Error > , F ) > {
923
973
drop ( self . commands ) ;
924
974
let result = self . main_thread . join ( ) ;
0 commit comments