11import * as path from "path"
2+ import * as fs from "fs/promises"
23
34import type { MockedFunction } from "vitest"
45
@@ -10,6 +11,7 @@ import { unescapeHtmlEntities } from "../../../utils/text-normalization"
1011import { everyLineHasLineNumbers , stripLineNumbers } from "../../../integrations/misc/extract-text"
1112import { ToolUse , ToolResponse } from "../../../shared/tools"
1213import { writeToFileTool } from "../writeToFileTool"
14+ import { experiments , EXPERIMENT_IDS } from "../../../shared/experiments"
1315
1416vi . mock ( "path" , async ( ) => {
1517 const originalPath = await vi . importActual ( "path" )
@@ -27,6 +29,19 @@ vi.mock("delay", () => ({
2729 default : vi . fn ( ) ,
2830} ) )
2931
32+ vi . mock ( "fs/promises" , ( ) => ( {
33+ readFile : vi . fn ( ) . mockResolvedValue ( "original content" ) ,
34+ } ) )
35+
36+ vi . mock ( "../../../shared/experiments" , ( ) => ( {
37+ experiments : {
38+ isEnabled : vi . fn ( ) . mockReturnValue ( false ) , // Default to disabled
39+ } ,
40+ EXPERIMENT_IDS : {
41+ PREVENT_FOCUS_DISRUPTION : "preventFocusDisruption" ,
42+ } ,
43+ } ) )
44+
3045vi . mock ( "../../../utils/fs" , ( ) => ( {
3146 fileExistsAtPath : vi . fn ( ) . mockResolvedValue ( false ) ,
3247} ) )
@@ -108,6 +123,8 @@ describe("writeToFileTool", () => {
108123 const mockedEveryLineHasLineNumbers = everyLineHasLineNumbers as MockedFunction < typeof everyLineHasLineNumbers >
109124 const mockedStripLineNumbers = stripLineNumbers as MockedFunction < typeof stripLineNumbers >
110125 const mockedPathResolve = path . resolve as MockedFunction < typeof path . resolve >
126+ const mockedExperimentsIsEnabled = experiments . isEnabled as MockedFunction < typeof experiments . isEnabled >
127+ const mockedFsReadFile = fs . readFile as MockedFunction < typeof fs . readFile >
111128
112129 const mockCline : any = { }
113130 let mockAskApproval : ReturnType < typeof vi . fn >
@@ -127,6 +144,8 @@ describe("writeToFileTool", () => {
127144 mockedUnescapeHtmlEntities . mockImplementation ( ( content ) => content )
128145 mockedEveryLineHasLineNumbers . mockReturnValue ( false )
129146 mockedStripLineNumbers . mockImplementation ( ( content ) => content )
147+ mockedExperimentsIsEnabled . mockReturnValue ( false ) // Default to disabled
148+ mockedFsReadFile . mockResolvedValue ( "original content" )
130149
131150 mockCline . cwd = "/"
132151 mockCline . consecutiveMistakeCount = 0
@@ -416,4 +435,98 @@ describe("writeToFileTool", () => {
416435 expect ( mockCline . diffViewProvider . reset ) . toHaveBeenCalled ( )
417436 } )
418437 } )
438+
439+ describe ( "PREVENT_FOCUS_DISRUPTION experiment" , ( ) => {
440+ beforeEach ( ( ) => {
441+ // Reset the experiments mock for these tests
442+ mockedExperimentsIsEnabled . mockReset ( )
443+ } )
444+
445+ it ( "should NOT save file before user approval when experiment is enabled" , async ( ) => {
446+ // Enable the PREVENT_FOCUS_DISRUPTION experiment
447+ mockedExperimentsIsEnabled . mockReturnValue ( true )
448+
449+ mockCline . providerRef . deref ( ) . getState . mockResolvedValue ( {
450+ diagnosticsEnabled : true ,
451+ writeDelayMs : 1000 ,
452+ experiments : {
453+ preventFocusDisruption : true ,
454+ } ,
455+ } )
456+
457+ // Mock saveDirectly to track when it's called
458+ const saveDirectlySpy = vi . fn ( ) . mockResolvedValue ( {
459+ newProblemsMessage : "" ,
460+ userEdits : undefined ,
461+ finalContent : testContent ,
462+ } )
463+ mockCline . diffViewProvider . saveDirectly = saveDirectlySpy
464+
465+ // User rejects the approval
466+ mockAskApproval . mockResolvedValue ( false )
467+
468+ await executeWriteFileTool ( { } , { fileExists : false } )
469+
470+ // Verify that askApproval was called
471+ expect ( mockAskApproval ) . toHaveBeenCalled ( )
472+
473+ // Verify that saveDirectly was NOT called since user rejected
474+ expect ( saveDirectlySpy ) . not . toHaveBeenCalled ( )
475+
476+ // Verify that the diffViewProvider state was reset
477+ expect ( mockCline . diffViewProvider . editType ) . toBe ( undefined )
478+ expect ( mockCline . diffViewProvider . originalContent ) . toBe ( undefined )
479+ } )
480+
481+ it ( "should save file AFTER user approval when experiment is enabled" , async ( ) => {
482+ // Enable the PREVENT_FOCUS_DISRUPTION experiment
483+ mockedExperimentsIsEnabled . mockReturnValue ( true )
484+
485+ mockCline . providerRef . deref ( ) . getState . mockResolvedValue ( {
486+ diagnosticsEnabled : true ,
487+ writeDelayMs : 1000 ,
488+ experiments : {
489+ preventFocusDisruption : true ,
490+ } ,
491+ } )
492+
493+ // Mock saveDirectly to track when it's called
494+ const saveDirectlySpy = vi . fn ( ) . mockResolvedValue ( {
495+ newProblemsMessage : "" ,
496+ userEdits : undefined ,
497+ finalContent : testContent ,
498+ } )
499+ mockCline . diffViewProvider . saveDirectly = saveDirectlySpy
500+
501+ // Mock pushToolWriteResult
502+ mockCline . diffViewProvider . pushToolWriteResult = vi . fn ( ) . mockResolvedValue ( "Tool result message" )
503+
504+ // User approves
505+ mockAskApproval . mockResolvedValue ( true )
506+
507+ // Track the order of calls
508+ const callOrder : string [ ] = [ ]
509+ mockAskApproval . mockImplementation ( async ( ) => {
510+ callOrder . push ( "askApproval" )
511+ return true
512+ } )
513+ saveDirectlySpy . mockImplementation ( async ( ) => {
514+ callOrder . push ( "saveDirectly" )
515+ return {
516+ newProblemsMessage : "" ,
517+ userEdits : undefined ,
518+ finalContent : testContent ,
519+ }
520+ } )
521+
522+ await executeWriteFileTool ( { } , { fileExists : false } )
523+
524+ // Verify that askApproval was called BEFORE saveDirectly
525+ expect ( callOrder ) . toEqual ( [ "askApproval" , "saveDirectly" ] )
526+
527+ // Verify both were called
528+ expect ( mockAskApproval ) . toHaveBeenCalled ( )
529+ expect ( saveDirectlySpy ) . toHaveBeenCalledWith ( testFilePath , testContent , false , true , 1000 )
530+ } )
531+ } )
419532} )
0 commit comments