@@ -567,18 +567,65 @@ describe('Channel', () => {
567567 const { channel, chatClient } = await initClient ( ) ;
568568 const threadMessage = messages [ 0 ] ;
569569 const hasThread = jest . fn ( ) ;
570+ const hasThreadInstance = jest . fn ( ) ;
571+ const mockThreadInstance = {
572+ threadInstanceMock : true ,
573+ registerSubscriptions : jest . fn ( ) ,
574+ } ;
575+ const getThreadSpy = jest
576+ . spyOn ( chatClient , 'getThread' )
577+ . mockResolvedValueOnce ( mockThreadInstance ) ;
570578
571579 // this renders Channel, calls openThread from a child context consumer with a message,
572580 // and then calls hasThread with the thread id if it was set.
573- await renderComponent ( { channel, chatClient } , ( { openThread, thread } ) => {
574- if ( ! thread ) {
575- openThread ( threadMessage , { preventDefault : ( ) => null } ) ;
576- } else {
577- hasThread ( thread . id ) ;
578- }
581+ await renderComponent (
582+ { channel, chatClient } ,
583+ ( { openThread, thread, threadInstance } ) => {
584+ if ( ! thread ) {
585+ openThread ( threadMessage , { preventDefault : ( ) => null } ) ;
586+ } else {
587+ hasThread ( thread . id ) ;
588+ hasThreadInstance ( threadInstance ) ;
589+ }
590+ } ,
591+ ) ;
592+
593+ await waitFor ( ( ) => {
594+ expect ( hasThread ) . toHaveBeenCalledWith ( threadMessage . id ) ;
595+ expect ( getThreadSpy ) . not . toHaveBeenCalled ( ) ;
596+ expect ( hasThreadInstance ) . toHaveBeenCalledWith ( undefined ) ;
579597 } ) ;
598+ getThreadSpy . mockRestore ( ) ;
599+ } ) ;
600+
601+ it ( 'uses Thread instance when messageDraftsEnabled is true' , async ( ) => {
602+ const { channel, chatClient } = await initClient ( ) ;
603+ const threadMessage = messages [ 0 ] ;
604+ const hasThreadInstance = jest . fn ( ) ;
605+ const mockThreadInstance = {
606+ threadInstanceMock : true ,
607+ registerSubscriptions : jest . fn ( ) ,
608+ } ;
609+ const spy = jest
610+ . spyOn ( chatClient , 'getThread' )
611+ . mockResolvedValueOnce ( mockThreadInstance ) ;
580612
581- await waitFor ( ( ) => expect ( hasThread ) . toHaveBeenCalledWith ( threadMessage . id ) ) ;
613+ await renderComponent (
614+ { channel, chatClient, messageDraftsEnabled : true } ,
615+ ( { openThread, thread, threadInstance } ) => {
616+ if ( ! thread ) {
617+ openThread ( threadMessage , { preventDefault : ( ) => null } ) ;
618+ } else {
619+ hasThreadInstance ( threadInstance ) ;
620+ }
621+ } ,
622+ ) ;
623+
624+ await waitFor ( ( ) => {
625+ expect ( hasThreadInstance ) . toHaveBeenCalledWith ( mockThreadInstance ) ;
626+ expect ( mockThreadInstance . registerSubscriptions ) . toHaveBeenCalledWith ( ) ;
627+ } ) ;
628+ spy . mockRestore ( ) ;
582629 } ) ;
583630
584631 it ( 'should be able to load more messages in a thread until reaching the end' , async ( ) => {
0 commit comments