diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index 503c5fce63..92e92d7411 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -27,11 +27,15 @@ + + + - diff --git a/app/src/main/java/com/amaze/filemanager/ui/dialogs/SmbSearchDialog.kt b/app/src/main/java/com/amaze/filemanager/ui/dialogs/SmbSearchDialog.kt index 0cd1a20f64..c902f06512 100644 --- a/app/src/main/java/com/amaze/filemanager/ui/dialogs/SmbSearchDialog.kt +++ b/app/src/main/java/com/amaze/filemanager/ui/dialogs/SmbSearchDialog.kt @@ -23,7 +23,6 @@ package com.amaze.filemanager.ui.dialogs import android.app.Activity import android.app.Dialog import android.content.Context -import android.graphics.Color import android.os.Bundle import android.view.LayoutInflater import android.view.View @@ -31,6 +30,7 @@ import android.view.ViewGroup import android.widget.Toast import androidx.appcompat.widget.AppCompatImageView import androidx.appcompat.widget.AppCompatTextView +import androidx.core.graphics.toColorInt import androidx.fragment.app.DialogFragment import androidx.lifecycle.MutableLiveData import androidx.lifecycle.ViewModel @@ -142,11 +142,8 @@ class SmbSearchDialog : DialogFragment() { context: Context, ) : RecyclerView.Adapter() { private val items: MutableList = ArrayList() - private val mInflater: LayoutInflater - - init { - mInflater = context.getSystemService(Activity.LAYOUT_INFLATER_SERVICE) as LayoutInflater - } + private val inflater: LayoutInflater = + context.getSystemService(Activity.LAYOUT_INFLATER_SERVICE) as LayoutInflater /** * Called by [ComputerParcelableViewModel], add found computer to list view @@ -194,12 +191,12 @@ class SmbSearchDialog : DialogFragment() { val view: View return when (viewType) { VIEW_PROGRESSBAR -> { - view = mInflater.inflate(R.layout.smb_progress_row, parent, false) + view = inflater.inflate(R.layout.smb_progress_row, parent, false) ViewHolder(view) } else -> { view = - mInflater.inflate(R.layout.smb_computers_row, parent, false) + inflater.inflate(R.layout.smb_computers_row, parent, false) ElementViewHolder(view) } } @@ -229,7 +226,7 @@ class SmbSearchDialog : DialogFragment() { holder.txtTitle.text = name holder.image.setImageResource(R.drawable.ic_settings_remote_white_48dp) if (utilsProvider.appTheme == AppTheme.LIGHT) { - holder.image.setColorFilter(Color.parseColor("#666666")) + holder.image.setColorFilter("#666666".toColorInt()) } holder.txtDesc.text = addr } diff --git a/app/src/main/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategy.kt b/app/src/main/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategy.kt new file mode 100644 index 0000000000..f59e23e775 --- /dev/null +++ b/app/src/main/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategy.kt @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2014-2024 Arpit Khurana , Vishal Nehra , + * Emmanuel Messulam, Raymond Lai and Contributors. + * + * This file is part of Amaze File Manager. + * + * Amaze File Manager is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package com.amaze.filemanager.utils.smb + +import android.content.Context.NSD_SERVICE +import android.content.Context.WIFI_SERVICE +import android.net.nsd.NsdManager +import android.net.nsd.NsdServiceInfo +import android.net.wifi.WifiManager +import android.os.Build.VERSION.SDK_INT +import android.os.Build.VERSION_CODES.UPSIDE_DOWN_CAKE +import com.amaze.filemanager.application.AppConfig +import com.amaze.filemanager.utils.ComputerParcelable +import org.slf4j.Logger +import org.slf4j.LoggerFactory + +/** + * [SmbDeviceScannerObservable.DiscoverDeviceStrategy] implementation using Android's + * [NsdManager] to discover SMB devices using mDNS/Bonjour/ZeroConf. + * + * @see SmbDeviceScannerObservable + * @see NsdManager + * + */ +class NsdManagerDiscoverDeviceStrategy : SmbDeviceScannerObservable.DiscoverDeviceStrategy { + companion object { + internal const val SERVICE_TYPE_SMB = "_smb._tcp." + private val logger: Logger = + LoggerFactory.getLogger(NsdManagerDiscoverDeviceStrategy::class.java) + } + + private val wifiManager: WifiManager = + AppConfig.getInstance().applicationContext.getSystemService(WIFI_SERVICE) as WifiManager + private val nsdManager: NsdManager = + AppConfig.getInstance().applicationContext.getSystemService(NSD_SERVICE) as NsdManager + + private var multicastLock: WifiManager.MulticastLock? = null + private var discoveryListener: NsdManager.DiscoveryListener? = null + + override fun discoverDevices(callback: (ComputerParcelable) -> Unit) { + multicastLock = + wifiManager.createMulticastLock("smb_mdns_discovery").apply { + setReferenceCounted(true) + } + multicastLock?.acquire() + discoveryListener = createDiscoveryListener(callback) + nsdManager.discoverServices( + SERVICE_TYPE_SMB, + NsdManager.PROTOCOL_DNS_SD, + discoveryListener, + ) + } + + override fun onCancel() { + discoveryListener?.let { + nsdManager.stopServiceDiscovery(it) + discoveryListener = null + } + multicastLock?.let { + if (it.isHeld) { + it.release() + } + } + } + + /** + * Creates a new [NsdManager.DiscoveryListener] to handle service discovery events. + * + * For backward compatibility, uses [NsdManager.ResolveListener] to resolve services + * and perform the callback. + */ + private fun createDiscoveryListener(callback: (ComputerParcelable) -> Unit): NsdManager.DiscoveryListener { + return object : NsdManager.DiscoveryListener { + override fun onServiceFound(serviceInfo: NsdServiceInfo) { + // Just to be sure. + if (serviceInfo.serviceType != SERVICE_TYPE_SMB) { + logger.warn("Unknown Service Type: ${serviceInfo.serviceType} for service: ${serviceInfo.serviceName}") + } else { + @Suppress("DEPRECATION") + nsdManager.resolveService( + serviceInfo, + object : NsdManager.ResolveListener { + override fun onServiceResolved(resolvedServiceInfo: NsdServiceInfo) { + val host = + if (SDK_INT >= UPSIDE_DOWN_CAKE) { + resolvedServiceInfo.hostAddresses.firstOrNull() + } else { + resolvedServiceInfo.host + } + if (host != null && host.hostAddress?.isNotEmpty() == true) { + val computer = + ComputerParcelable( + name = resolvedServiceInfo.serviceName, + addr = host.hostAddress!!, + ) + callback(computer) + } + } + + override fun onResolveFailed( + serviceInfo: NsdServiceInfo?, + errorCode: Int, + ) { + logger.error( + "Service resolve failed: ${serviceInfo?.serviceName} with error code: $errorCode", + ) + } + }, + ) + } + } + + override fun onServiceLost(serviceInfo: NsdServiceInfo?) { + logger.debug("Service lost: ${serviceInfo?.serviceName}") + } + + override fun onStartDiscoveryFailed( + serviceType: String, + errorCode: Int, + ) { + logger.error("Service discovery start failed: $serviceType with error code: $errorCode") + nsdManager.stopServiceDiscovery(this) + } + + override fun onStopDiscoveryFailed( + serviceType: String, + errorCode: Int, + ) { + logger.debug("Service discovery stop failed: $serviceType with error code: $errorCode") + nsdManager.stopServiceDiscovery(this) + } + + override fun onDiscoveryStarted(serviceType: String?) = logger.debug("Service discovery started: $serviceType") + + override fun onDiscoveryStopped(serviceType: String?) = logger.debug("Service discovery stopped: $serviceType") + } + } +} diff --git a/app/src/main/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservable.kt b/app/src/main/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservable.kt index 7e36143a4f..ba5dc244d7 100644 --- a/app/src/main/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservable.kt +++ b/app/src/main/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservable.kt @@ -22,7 +22,6 @@ package com.amaze.filemanager.utils.smb import androidx.annotation.VisibleForTesting import com.amaze.filemanager.utils.ComputerParcelable -import com.amaze.filemanager.utils.smb.SmbDeviceScannerObservable.DiscoverDeviceStrategy import io.reactivex.Observable import io.reactivex.Observer import io.reactivex.disposables.Disposable @@ -55,6 +54,7 @@ class SmbDeviceScannerObservable : Observable() { arrayOf( WsddDiscoverDeviceStrategy(), SameSubnetDiscoverDeviceStrategy(), + NsdManagerDiscoverDeviceStrategy(), ) @VisibleForTesting set @@ -86,19 +86,22 @@ class SmbDeviceScannerObservable : Observable() { */ override fun subscribeActual(observer: Observer) { this.observer = observer + observer.onSubscribe(Disposables.empty()) this.disposable = merge( discoverDeviceStrategies.map { strategy -> - fromCallable { + create { emitter -> strategy.discoverDevices { addr -> - observer.onNext(ComputerParcelable(addr.addr, addr.name)) + if (!emitter.isDisposed) { + emitter.onNext(addr) + } } + emitter.setCancellable { strategy.onCancel() } }.subscribeOn(Schedulers.io()) }, - ).observeOn(Schedulers.computation()).doOnComplete { - discoverDeviceStrategies.forEach { strategy -> - strategy.onCancel() - } - }.subscribe() + ).observeOn(Schedulers.computation()).subscribe( + { computer -> observer.onNext(computer) }, + { error -> observer.onError(error) }, + ) } } diff --git a/app/src/test/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategyTest.kt b/app/src/test/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategyTest.kt new file mode 100644 index 0000000000..b10f198141 --- /dev/null +++ b/app/src/test/java/com/amaze/filemanager/utils/smb/NsdManagerDiscoverDeviceStrategyTest.kt @@ -0,0 +1,702 @@ +/* + * Copyright (C) 2014-2024 Arpit Khurana , Vishal Nehra , + * Emmanuel Messulam, Raymond Lai and Contributors. + * + * This file is part of Amaze File Manager. + * + * Amaze File Manager is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package com.amaze.filemanager.utils.smb + +import android.content.Context +import android.net.nsd.NsdManager +import android.net.nsd.NsdServiceInfo +import android.net.wifi.WifiManager +import android.os.Build.VERSION_CODES.LOLLIPOP +import android.os.Build.VERSION_CODES.P +import android.os.Build.VERSION_CODES.R +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.amaze.filemanager.utils.ComputerParcelable +import com.amaze.filemanager.utils.smb.NsdManagerDiscoverDeviceStrategy.Companion.SERVICE_TYPE_SMB +import io.mockk.every +import io.mockk.mockk +import io.mockk.slot +import io.mockk.verify +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.annotation.Config +import java.net.InetAddress +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +/** + * Unit tests for [NsdManagerDiscoverDeviceStrategy]. + */ +@Suppress("LongClass", "StringLiteralDuplication") +@RunWith(AndroidJUnit4::class) +@Config(sdk = [LOLLIPOP, P, R]) +class NsdManagerDiscoverDeviceStrategyTest { + private lateinit var context: Context + private lateinit var mockNsdManager: NsdManager + private lateinit var mockWifiManager: WifiManager + private lateinit var mockMulticastLock: WifiManager.MulticastLock + + /** + * Set up mocks before each test. + */ + @Before + fun setUp() { + context = ApplicationProvider.getApplicationContext() + mockNsdManager = mockk(relaxed = true) + mockWifiManager = mockk(relaxed = true) + mockMulticastLock = mockk(relaxed = true) + + every { mockWifiManager.createMulticastLock(any()) } returns mockMulticastLock + every { mockMulticastLock.isHeld } returns true + } + + /** + * Clean up after each test. + */ + @After + fun tearDown() { + // Clean up if necessary + } + + /** + * Test that onServiceFound callback properly invokes the callback with discovered device. + */ + @Test + fun testOnServiceFoundInvokesCallback() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo for discovery + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns SERVICE_TYPE_SMB + + // Create a mock resolved NsdServiceInfo with a valid host + val mockResolvedServiceInfo = mockk() + val mockHost = mockk() + every { mockResolvedServiceInfo.serviceName } returns "TestServer" + every { mockResolvedServiceInfo.host } returns mockHost + every { mockHost.hostAddress } returns "192.168.1.100" + + // Mock resolveService to capture the ResolveListener and invoke onServiceResolved + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + val latch = CountDownLatch(1) + + strategy.discoverDevices { computer -> + result.add(computer) + latch.countDown() + } + + // Trigger onServiceFound + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Wait for callback + latch.await(1, TimeUnit.SECONDS) + + // Verify callback was invoked with correct data + assertEquals(1, result.size) + assertEquals("TestServer", result[0].name) + assertEquals("192.168.1.100", result[0].addr) + } + + /** + * Test that onServiceFound does not invoke callback when host is null. + */ + @Test + fun testOnServiceFoundDoesNotInvokeCallbackWhenHostIsNull() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo for discovery + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns SERVICE_TYPE_SMB + + // Create a mock resolved NsdServiceInfo with null host + val mockResolvedServiceInfo = mockk() + every { mockResolvedServiceInfo.serviceName } returns "TestServer" + every { mockResolvedServiceInfo.host } returns null + + // Mock resolveService to capture the ResolveListener and invoke onServiceResolved + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + + strategy.discoverDevices { computer -> + result.add(computer) + } + + // Trigger onServiceFound + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Verify callback was NOT invoked + assertEquals(0, result.size) + } + + /** + * Test that onServiceFound does not invoke callback when hostAddress is empty. + */ + @Test + fun testOnServiceFoundDoesNotInvokeCallbackWhenHostAddressIsEmpty() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo for discovery + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns SERVICE_TYPE_SMB + + // Create a mock resolved NsdServiceInfo with empty hostAddress + val mockResolvedServiceInfo = mockk() + val mockHost = mockk() + every { mockResolvedServiceInfo.serviceName } returns "TestServer" + every { mockResolvedServiceInfo.host } returns mockHost + every { mockHost.hostAddress } returns "" + + // Mock resolveService to capture the ResolveListener and invoke onServiceResolved + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + + strategy.discoverDevices { computer -> + result.add(computer) + } + + // Trigger onServiceFound + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Verify callback was NOT invoked + assertEquals(0, result.size) + } + + /** + * Test that onServiceFound does not invoke callback when hostAddress is null. + */ + @Test + fun testOnServiceFoundDoesNotInvokeCallbackWhenHostAddressIsNull() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo for discovery + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns SERVICE_TYPE_SMB + + // Create a mock resolved NsdServiceInfo with null hostAddress + val mockResolvedServiceInfo = mockk() + val mockHost = mockk() + every { mockResolvedServiceInfo.serviceName } returns "TestServer" + every { mockResolvedServiceInfo.host } returns mockHost + every { mockHost.hostAddress } returns null + + // Mock resolveService to capture the ResolveListener and invoke onServiceResolved + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + + strategy.discoverDevices { computer -> + result.add(computer) + } + + // Trigger onServiceFound + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Verify callback was NOT invoked + assertEquals(0, result.size) + } + + /** + * Test that discoverDevices acquires multicast lock and starts NSD discovery. + */ + @Test + fun testDiscoverDevicesStartsDiscovery() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + // Simulate discovery started callback + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + val latch = CountDownLatch(1) + + strategy.discoverDevices { computer -> + result.add(computer) + latch.countDown() + } + + // Verify that multicast lock was acquired + verify { mockMulticastLock.setReferenceCounted(true) } + verify { mockMulticastLock.acquire() } + + // Verify that discovery was started + verify { + mockNsdManager.discoverServices( + SERVICE_TYPE_SMB, + NsdManager.PROTOCOL_DNS_SD, + any(), + ) + } + } + + /** + * Test that onCancel stops NSD discovery and releases multicast lock. + */ + @Test + fun testOnCancelStopsDiscoveryAndReleasesLock() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + strategy.onCancel() + + // Verify that multicast lock was released + verify { mockMulticastLock.release() } + } + + /** + * Test that onCancel does not release lock if not held. + */ + @Test + fun testOnCancelDoesNotReleaseLockIfNotHeld() { + every { mockMulticastLock.isHeld } returns false + + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + strategy.onCancel() + + // Verify that release was not called since lock is not held + verify(exactly = 0) { mockMulticastLock.release() } + } + + /** + * Test onStartDiscoveryFailed callback stops discovery. + */ + @Test + fun testOnStartDiscoveryFailedStopsDiscovery() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + // Simulate discovery failed callback + listenerSlot.captured.onStartDiscoveryFailed( + SERVICE_TYPE_SMB, + NsdManager.FAILURE_INTERNAL_ERROR, + ) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + + // Verify that stopServiceDiscovery was called + verify { mockNsdManager.stopServiceDiscovery(any()) } + } + + /** + * Test onStopDiscoveryFailed callback stops discovery. + */ + @Test + fun testOnStopDiscoveryFailedStopsDiscovery() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + + // Manually trigger stop discovery failed + listenerSlot.captured.onStopDiscoveryFailed( + SERVICE_TYPE_SMB, + NsdManager.FAILURE_INTERNAL_ERROR, + ) + + // Verify that stopServiceDiscovery was called + verify { mockNsdManager.stopServiceDiscovery(any()) } + } + + /** + * Test discovery stopped callback logs correctly. + */ + @Test + fun testOnDiscoveryStoppedLogs() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + + // Should not throw - just logs + listenerSlot.captured.onDiscoveryStopped(SERVICE_TYPE_SMB) + } + + /** + * Test that onResolveFailed does not invoke callback and logs error. + */ + @Test + fun testOnResolveFailedDoesNotInvokeCallback() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo for discovery + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns SERVICE_TYPE_SMB + + // Mock resolveService to capture the ResolveListener and invoke onResolveFailed + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onResolveFailed(mockServiceInfo, NsdManager.FAILURE_INTERNAL_ERROR) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + + strategy.discoverDevices { computer -> + result.add(computer) + } + + // Trigger onServiceFound which will trigger onResolveFailed + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Verify callback was NOT invoked + assertEquals(0, result.size) + } + + /** + * Test that onServiceLost logs and does not crash. + */ + @Test + fun testOnServiceLostLogs() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + + // Create a mock NsdServiceInfo + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + + // Should not throw - just logs + listenerSlot.captured.onServiceLost(mockServiceInfo) + } + + /** + * Test that onServiceLost handles null serviceInfo. + */ + @Test + fun testOnServiceLostHandlesNull() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + val strategy = createStrategyWithMocks() + + strategy.discoverDevices { } + + // Should not throw - just logs + listenerSlot.captured.onServiceLost(null) + } + + /** + * Test that onServiceFound does not resolve service when serviceType does not match. + */ + @Test + fun testOnServiceFoundDoesNotResolveWhenServiceTypeDoesNotMatch() { + val listenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create a mock NsdServiceInfo with non-matching service type + val mockServiceInfo = mockk() + every { mockServiceInfo.serviceName } returns "TestServer" + every { mockServiceInfo.serviceType } returns "_http._tcp." + + val strategy = createStrategyWithMocks() + val result = ArrayList() + + strategy.discoverDevices { computer -> + result.add(computer) + } + + // Trigger onServiceFound with non-matching service type + listenerSlot.captured.onServiceFound(mockServiceInfo) + + // Verify resolveService was NOT called + verify(exactly = 0) { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(any(), any()) + } + + // Verify callback was NOT invoked + assertEquals(0, result.size) + } + + /** + * Test discovering multiple devices. + */ + @Test + fun testDiscoverMultipleDevices() { + val listenerSlot = slot() + val resolveListenerSlot = slot() + every { + mockNsdManager.discoverServices( + eq(SERVICE_TYPE_SMB), + eq(NsdManager.PROTOCOL_DNS_SD), + capture(listenerSlot), + ) + } answers { + listenerSlot.captured.onDiscoveryStarted(SERVICE_TYPE_SMB) + } + + // Create mock NsdServiceInfo for first device + val mockServiceInfo1 = mockk() + every { mockServiceInfo1.serviceName } returns "Server1" + every { mockServiceInfo1.serviceType } returns SERVICE_TYPE_SMB + + val mockResolvedServiceInfo1 = mockk() + val mockHost1 = mockk() + every { mockResolvedServiceInfo1.serviceName } returns "Server1" + every { mockResolvedServiceInfo1.host } returns mockHost1 + every { mockHost1.hostAddress } returns "192.168.1.100" + + // Create mock NsdServiceInfo for second device + val mockServiceInfo2 = mockk() + every { mockServiceInfo2.serviceName } returns "Server2" + every { mockServiceInfo2.serviceType } returns SERVICE_TYPE_SMB + + val mockResolvedServiceInfo2 = mockk() + val mockHost2 = mockk() + every { mockResolvedServiceInfo2.serviceName } returns "Server2" + every { mockResolvedServiceInfo2.host } returns mockHost2 + every { mockHost2.hostAddress } returns "192.168.1.101" + + // Mock resolveService for both devices + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo1), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo1) + } + + every { + @Suppress("DEPRECATION") + mockNsdManager.resolveService(eq(mockServiceInfo2), capture(resolveListenerSlot)) + } answers { + resolveListenerSlot.captured.onServiceResolved(mockResolvedServiceInfo2) + } + + val strategy = createStrategyWithMocks() + val result = ArrayList() + val latch = CountDownLatch(2) + + strategy.discoverDevices { computer -> + result.add(computer) + latch.countDown() + } + + // Trigger onServiceFound for both devices + listenerSlot.captured.onServiceFound(mockServiceInfo1) + listenerSlot.captured.onServiceFound(mockServiceInfo2) + + // Wait for callbacks + latch.await(1, TimeUnit.SECONDS) + + // Verify both callbacks were invoked + assertEquals(2, result.size) + assertEquals("Server1", result[0].name) + assertEquals("192.168.1.100", result[0].addr) + assertEquals("Server2", result[1].name) + assertEquals("192.168.1.101", result[1].addr) + } + + /** + * Helper method to create a strategy instance with mocked dependencies. + * Uses reflection to inject mocks since the class initializes managers in constructor. + */ + private fun createStrategyWithMocks(): NsdManagerDiscoverDeviceStrategy { + val strategy = NsdManagerDiscoverDeviceStrategy() + + // Use reflection to replace the private fields with our mocks + val wifiManagerField = + NsdManagerDiscoverDeviceStrategy::class.java + .getDeclaredField("wifiManager") + wifiManagerField.isAccessible = true + wifiManagerField.set(strategy, mockWifiManager) + + val nsdManagerField = + NsdManagerDiscoverDeviceStrategy::class.java + .getDeclaredField("nsdManager") + nsdManagerField.isAccessible = true + nsdManagerField.set(strategy, mockNsdManager) + + return strategy + } +} diff --git a/app/src/test/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservableTest.kt b/app/src/test/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservableTest.kt new file mode 100644 index 0000000000..b1186bd4f3 --- /dev/null +++ b/app/src/test/java/com/amaze/filemanager/utils/smb/SmbDeviceScannerObservableTest.kt @@ -0,0 +1,378 @@ +/* + * Copyright (C) 2014-2022 Arpit Khurana , Vishal Nehra , + * Emmanuel Messulam, Raymond Lai and Contributors. + * + * This file is part of Amaze File Manager. + * + * Amaze File Manager is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package com.amaze.filemanager.utils.smb + +import android.os.Build.VERSION_CODES.LOLLIPOP +import android.os.Build.VERSION_CODES.P +import android.os.Build.VERSION_CODES.R +import android.os.Looper +import androidx.test.ext.junit.runners.AndroidJUnit4 +import com.amaze.filemanager.utils.ComputerParcelable +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import io.reactivex.plugins.RxJavaPlugins +import io.reactivex.schedulers.Schedulers +import org.junit.After +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.Shadows.shadowOf +import org.robolectric.annotation.Config +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +/** + * Unit tests for [SmbDeviceScannerObservable]. + */ +@Suppress("LongClass", "StringLiteralDuplication") +@RunWith(AndroidJUnit4::class) +@Config(sdk = [LOLLIPOP, P, R]) +class SmbDeviceScannerObservableTest { + private lateinit var mockStrategy1: SmbDeviceScannerObservable.DiscoverDeviceStrategy + private lateinit var mockStrategy2: SmbDeviceScannerObservable.DiscoverDeviceStrategy + + /** + * Set up mocks before each test. + */ + @Before + fun setUp() { + mockStrategy1 = mockk(relaxed = true) + mockStrategy2 = mockk(relaxed = true) + + // Override RxJava schedulers to use trampoline for testing + RxJavaPlugins.setIoSchedulerHandler { Schedulers.trampoline() } + RxJavaPlugins.setComputationSchedulerHandler { Schedulers.trampoline() } + } + + /** + * Reset RxJava plugins after each test. + */ + @After + fun tearDown() { + RxJavaPlugins.reset() + } + + /** + * Test that subscription triggers device discovery on all strategies. + */ + @Test + fun testSubscriptionTriggersDiscovery() { + val latch = CountDownLatch(1) + + every { mockStrategy1.discoverDevices(any()) } answers { + // Do nothing, just capture the callback + } + every { mockStrategy2.discoverDevices(any()) } answers { + // Do nothing, just capture the callback + } + + val observable = createObservableWithMockStrategies() + + observable + .subscribeOn(Schedulers.trampoline()) + .observeOn(Schedulers.trampoline()) + .subscribe() + + // Give time for async operations + latch.await(500, TimeUnit.MILLISECONDS) + + // Verify both strategies were called + verify { mockStrategy1.discoverDevices(any()) } + verify { mockStrategy2.discoverDevices(any()) } + } + + /** + * Test that discovered devices are emitted to the observer. + */ + @Test + fun testDiscoveredDevicesAreEmitted() { + val computer1 = ComputerParcelable("Server1", "192.168.1.100") + val computer2 = ComputerParcelable("Server2", "192.168.1.101") + + every { mockStrategy1.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer1) + } + every { mockStrategy2.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer2) + } + + val observable = createObservableWithMockStrategies() + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Verify both strategies were called and emitted devices + verify { mockStrategy1.discoverDevices(any()) } + verify { mockStrategy2.discoverDevices(any()) } + } + + /** + * Test that multiple devices from the same strategy are emitted. + * This test verifies that strategy.discoverDevices is called and + * that the observable can receive callbacks. + */ + @Test + fun testMultipleDevicesFromSameStrategy() { + val computer1 = ComputerParcelable("Server1", "192.168.1.100") + + every { mockStrategy1.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer1) + } + + // Create observable with single strategy + val observable = SmbDeviceScannerObservable() + val field = SmbDeviceScannerObservable::class.java.getDeclaredField("discoverDeviceStrategies") + field.isAccessible = true + field.set(observable, arrayOf(mockStrategy1)) + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Verify that strategy was called + verify { mockStrategy1.discoverDevices(any()) } + } + + /** + * Test that stop() calls onCancel on all strategies. + */ + @Test + fun testStopCallsOnCancelOnAllStrategies() { + val latch = CountDownLatch(1) + + every { mockStrategy1.discoverDevices(any()) } answers { + // Do nothing + } + every { mockStrategy2.discoverDevices(any()) } answers { + // Do nothing + } + + val observable = createObservableWithMockStrategies() + + observable + .subscribeOn(Schedulers.trampoline()) + .observeOn(Schedulers.trampoline()) + .subscribe() + + // Wait a bit for subscription to complete + latch.await(200, TimeUnit.MILLISECONDS) + + observable.stop() + + // Verify onCancel was called on both strategies + verify { mockStrategy1.onCancel() } + verify { mockStrategy2.onCancel() } + } + + /** + * Test that errors from strategies are propagated to the observer. + */ + @Test + fun testErrorsArePropagated() { + val latch = CountDownLatch(1) + val errorReceived = AtomicBoolean(false) + val testException = RuntimeException("Test error") + + every { mockStrategy1.discoverDevices(any()) } throws testException + every { mockStrategy2.discoverDevices(any()) } answers { + // Do nothing + } + + val observable = createObservableWithMockStrategies() + + observable + .subscribeOn(Schedulers.trampoline()) + .observeOn(Schedulers.trampoline()) + .subscribe( + { }, + { error -> + if (error is RuntimeException && error.message == "Test error") { + errorReceived.set(true) + } + latch.countDown() + }, + ) + + latch.await(1, TimeUnit.SECONDS) + + assertTrue("Error should be propagated to observer", errorReceived.get()) + } + + /** + * Test that discovery continues even if one strategy doesn't find anything. + */ + @Test + fun testDiscoveryContinuesIfOneStrategyFindsNothing() { + val computer = ComputerParcelable("Server1", "192.168.1.100") + + every { mockStrategy1.discoverDevices(any()) } answers { + // Strategy 1 finds nothing + } + every { mockStrategy2.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer) + } + + val observable = createObservableWithMockStrategies() + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Verify both strategies were called + verify { mockStrategy1.discoverDevices(any()) } + verify { mockStrategy2.discoverDevices(any()) } + } + + /** + * Test that duplicate devices (same address) can be emitted (filtering is done by consumer). + */ + @Test + fun testDuplicateDevicesAreEmitted() { + val computer1 = ComputerParcelable("Server1", "192.168.1.100") + val computer2 = ComputerParcelable("Server1", "192.168.1.100") // Same as computer1 + + every { mockStrategy1.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer1) + } + every { mockStrategy2.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer2) + } + + val observable = createObservableWithMockStrategies() + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Verify both strategies were called (both emit - consumer is responsible for deduplication) + verify { mockStrategy1.discoverDevices(any()) } + verify { mockStrategy2.discoverDevices(any()) } + } + + /** + * Test that observable works with a single strategy. + */ + @Test + fun testSingleStrategy() { + val computer = ComputerParcelable("Server1", "192.168.1.100") + + every { mockStrategy1.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer) + } + + val observable = SmbDeviceScannerObservable() + // Use reflection to set single strategy + val field = SmbDeviceScannerObservable::class.java.getDeclaredField("discoverDeviceStrategies") + field.isAccessible = true + field.set(observable, arrayOf(mockStrategy1)) + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Verify that strategy was called + verify { mockStrategy1.discoverDevices(any()) } + } + + /** + * Test that callbacks are invoked even if emitter is not disposed. + */ + @Test + fun testCallbacksInvokedWhenNotDisposed() { + val callbackInvoked = AtomicBoolean(false) + + val computer = ComputerParcelable("Server1", "192.168.1.100") + + every { mockStrategy1.discoverDevices(any()) } answers { + val callback = firstArg<(ComputerParcelable) -> Unit>() + callback(computer) + callbackInvoked.set(true) + } + every { mockStrategy2.discoverDevices(any()) } answers { + // Do nothing + } + + val observable = createObservableWithMockStrategies() + + observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + assertTrue("Callback should have been invoked", callbackInvoked.get()) + } + + /** + * Test that disposing stops receiving new devices. + */ + @Test + fun testDisposingStopsReceivingDevices() { + every { mockStrategy1.discoverDevices(any()) } answers { + // Do nothing - just verify we can subscribe and dispose + } + every { mockStrategy2.discoverDevices(any()) } answers { + // Do nothing + } + + val observable = createObservableWithMockStrategies() + + val disposable = observable.subscribe() + + // Process any pending runnables + shadowOf(Looper.getMainLooper()).idle() + + // Dispose + disposable.dispose() + + // Verify subscription was created and disposed + assertTrue("Subscription should be disposed", disposable.isDisposed) + } + + /** + * Helper method to create an observable with mock strategies. + */ + private fun createObservableWithMockStrategies(): SmbDeviceScannerObservable { + val observable = SmbDeviceScannerObservable() + + // Use reflection to replace the strategies with our mocks + val field = SmbDeviceScannerObservable::class.java.getDeclaredField("discoverDeviceStrategies") + field.isAccessible = true + field.set(observable, arrayOf(mockStrategy1, mockStrategy2)) + + return observable + } +}