Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 51 additions & 19 deletions server/src/main/java/com/cloud/api/query/QueryManagerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -3813,11 +3812,62 @@ else if (!template.isPublicTemplate() && caller.getType() != Account.Type.ADMIN)
}
}

applyPublicTemplateSharingRestrictions(sc, caller);

return templateChecks(isIso, hypers, tags, name, keyword, hyperType, onlyReady, bootable, zoneId, showDomr, caller,
showRemovedTmpl, parentTemplateId, showUnique, searchFilter, sc);

}

/**
* If the caller is not a root admin, restricts the search to return only public templates from the domain which
* the caller belongs to and domains with the setting 'share.public.templates.with.other.domains' enabled.
*/
protected void applyPublicTemplateSharingRestrictions(SearchCriteria<TemplateJoinVO> sc, Account caller) {
if (caller.getType() == Account.Type.ADMIN) {
s_logger.debug(String.format("Account [%s] is a root admin. Therefore, it has access to all public templates.", caller));
return;
}

List<TemplateJoinVO> publicTemplates = _templateJoinDao.listPublicTemplates();

Set<Long> unsharableDomainIds = new HashSet<>();
for (TemplateJoinVO template : publicTemplates) {
addDomainIdToSetIfDomainDoesNotShareTemplates(template.getDomainId(), caller, unsharableDomainIds);
}

if (!unsharableDomainIds.isEmpty()) {
s_logger.info(String.format("The public templates belonging to the domains [%s] will not be listed to account [%s] as they have the configuration [%s] marked as 'false'.", unsharableDomainIds, caller, QueryService.SharePublicTemplatesWithOtherDomains.key()));
sc.addAnd("domainId", SearchCriteria.Op.NOTIN, unsharableDomainIds.toArray());
}
}

/**
* Adds the provided domain ID to the set if the domain does not share templates with the account. That is, if:
* (1) the template does not belong to the domain of the account AND
* (2) the domain of the template has the setting 'share.public.templates.with.other.domains' disabled.
*/
protected void addDomainIdToSetIfDomainDoesNotShareTemplates(long domainId, Account account, Set<Long> unsharableDomainIds) {
if (domainId == account.getDomainId()) {
s_logger.trace(String.format("Domain [%s] will not be added to the set of domains with unshared templates since the account [%s] belongs to it.", domainId, account));
return;
}

if (unsharableDomainIds.contains(domainId)) {
s_logger.trace(String.format("Domain [%s] is already on the set of domains with unshared templates.", domainId));
return;
}

if (!checkIfDomainSharesTemplates(domainId)) {
s_logger.debug(String.format("Domain [%s] will be added to the set of domains with unshared templates as configuration [%s] is false.", domainId, QueryService.SharePublicTemplatesWithOtherDomains.key()));
unsharableDomainIds.add(domainId);
}
}

protected boolean checkIfDomainSharesTemplates(Long domainId) {
return QueryService.SharePublicTemplatesWithOtherDomains.valueIn(domainId);
}

private Pair<List<TemplateJoinVO>, Integer> templateChecks(boolean isIso, List<HypervisorType> hypers, Map<String, String> tags, String name, String keyword,
HypervisorType hyperType, boolean onlyReady, Boolean bootable, Long zoneId, boolean showDomr, Account caller,
boolean showRemovedTmpl, Long parentTemplateId, Boolean showUnique,
Expand Down Expand Up @@ -3947,27 +3997,9 @@ private Pair<List<TemplateJoinVO>, Integer> findTemplatesByIdOrTempZonePair(Pair
templates = _templateJoinDao.searchByTemplateZonePair(showRemoved, templateZonePairs);
}

if(caller.getType() != Account.Type.ADMIN) {
templates = applyPublicTemplateRestriction(templates, caller);
count = templates.size();
}

return new Pair<List<TemplateJoinVO>, Integer>(templates, count);
}

private List<TemplateJoinVO> applyPublicTemplateRestriction(List<TemplateJoinVO> templates, Account caller){
List<Long> unsharableDomainIds = templates.stream()
.map(TemplateJoinVO::getDomainId)
.distinct()
.filter(domainId -> domainId != caller.getDomainId())
.filter(Predicate.not(QueryService.SharePublicTemplatesWithOtherDomains::valueIn))
.collect(Collectors.toList());

return templates.stream()
.filter(Predicate.not(t -> unsharableDomainIds.contains(t.getDomainId())))
.collect(Collectors.toList());
}

@Override
public ListResponse<TemplateResponse> listIsos(ListIsosCmd cmd) {
Pair<List<TemplateJoinVO>, Integer> result = searchForIsosInternal(cmd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public interface TemplateJoinDao extends GenericDao<TemplateJoinVO, Long> {

List<TemplateJoinVO> listActiveTemplates(long storeId);

List<TemplateJoinVO> listPublicTemplates();

Pair<List<TemplateJoinVO>, Integer> searchIncludingRemovedAndCount(final SearchCriteria<TemplateJoinVO> sc, final Filter filter);

List<TemplateJoinVO> findByDistinctIds(Long... ids);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ public class TemplateJoinDaoImpl extends GenericDaoBaseWithTagInformation<Templa

private final SearchBuilder<TemplateJoinVO> activeTmpltSearch;

private final SearchBuilder<TemplateJoinVO> publicTmpltSearch;

protected TemplateJoinDaoImpl() {

tmpltIdPairSearch = createSearchBuilder();
Expand Down Expand Up @@ -137,6 +139,10 @@ protected TemplateJoinDaoImpl() {
activeTmpltSearch.cp();
activeTmpltSearch.done();

publicTmpltSearch = createSearchBuilder();
publicTmpltSearch.and("public", publicTmpltSearch.entity().isPublicTemplate(), SearchCriteria.Op.EQ);
publicTmpltSearch.done();

// select distinct pair (template_id, zone_id)
_count = "select count(distinct temp_zone_pair) from template_view WHERE ";
}
Expand Down Expand Up @@ -572,6 +578,13 @@ public List<TemplateJoinVO> listActiveTemplates(long storeId) {
return searchIncludingRemoved(sc, null, null, false);
}

@Override
public List<TemplateJoinVO> listPublicTemplates() {
SearchCriteria<TemplateJoinVO> sc = publicTmpltSearch.create();
sc.setParameters("public", Boolean.TRUE);
return listBy(sc);
}

@Override
public Pair<List<TemplateJoinVO>, Integer> searchIncludingRemovedAndCount(final SearchCriteria<TemplateJoinVO> sc, final Filter filter) {
List<TemplateJoinVO> objects = searchIncludingRemoved(sc, filter, null, false);
Expand Down
83 changes: 83 additions & 0 deletions server/src/test/java/com/cloud/api/query/QueryManagerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package com.cloud.api.query;

import com.cloud.api.query.dao.TemplateJoinDao;
import com.cloud.api.query.vo.EventJoinVO;
import com.cloud.api.query.vo.TemplateJoinVO;
import com.cloud.event.dao.EventJoinDao;
import com.cloud.exception.InvalidParameterValueException;
import com.cloud.exception.PermissionDeniedException;
Expand Down Expand Up @@ -48,10 +50,13 @@
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.Spy;
import org.mockito.junit.MockitoJUnitRunner;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import static org.mockito.Mockito.when;
Expand All @@ -61,13 +66,28 @@ public class QueryManagerImplTest {
public static final long USER_ID = 1;
public static final long ACCOUNT_ID = 1;

@Spy
@InjectMocks
private QueryManagerImpl queryManagerImplSpy = new QueryManagerImpl();

@Mock
EntityManager entityManager;

@Mock
AccountManager accountManager;

@Mock
EventJoinDao eventJoinDao;

@Mock
Account accountMock;

@Mock
TemplateJoinDao templateJoinDaoMock;

@Mock
SearchCriteria searchCriteriaMock;

private AccountVO account;
private UserVO user;

Expand Down Expand Up @@ -176,4 +196,67 @@ public void searchForEventsFailPermissionDenied() {
Mockito.doThrow(new PermissionDeniedException("Denied")).when(accountManager).checkAccess(account, SecurityChecker.AccessType.ListEntry, false, network);
queryManager.searchForEvents(cmd);
}

@Test
public void applyPublicTemplateRestrictionsTestDoesNotApplyRestrictionsWhenCallerIsRootAdmin() {
Mockito.when(accountMock.getType()).thenReturn(Account.Type.ADMIN);

queryManagerImplSpy.applyPublicTemplateSharingRestrictions(searchCriteriaMock, accountMock);

Mockito.verify(searchCriteriaMock, Mockito.never()).addAnd(Mockito.anyString(), Mockito.any(), Mockito.any());
}

@Test
public void applyPublicTemplateRestrictionsTestAppliesRestrictionsWhenCallerIsNotRootAdmin() {
long callerDomainId = 1L;
long sharableDomainId = 2L;
long unsharableDomainId = 3L;

Mockito.when(accountMock.getType()).thenReturn(Account.Type.NORMAL);

Mockito.when(accountMock.getDomainId()).thenReturn(callerDomainId);
TemplateJoinVO templateMock1 = Mockito.mock(TemplateJoinVO.class);
Mockito.when(templateMock1.getDomainId()).thenReturn(callerDomainId);
Mockito.lenient().doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(callerDomainId);

TemplateJoinVO templateMock2 = Mockito.mock(TemplateJoinVO.class);
Mockito.when(templateMock2.getDomainId()).thenReturn(sharableDomainId);
Mockito.doReturn(true).when(queryManagerImplSpy).checkIfDomainSharesTemplates(sharableDomainId);

TemplateJoinVO templateMock3 = Mockito.mock(TemplateJoinVO.class);
Mockito.when(templateMock3.getDomainId()).thenReturn(unsharableDomainId);
Mockito.doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(unsharableDomainId);

List<TemplateJoinVO> publicTemplates = List.of(templateMock1, templateMock2, templateMock3);
Mockito.when(templateJoinDaoMock.listPublicTemplates()).thenReturn(publicTemplates);

queryManagerImplSpy.applyPublicTemplateSharingRestrictions(searchCriteriaMock, accountMock);

Mockito.verify(searchCriteriaMock).addAnd("domainId", SearchCriteria.Op.NOTIN, unsharableDomainId);
}

@Test
public void addDomainIdToSetIfDomainDoesNotShareTemplatesTestDoesNotAddWhenCallerBelongsToDomain() {
long domainId = 1L;
Set<Long> set = new HashSet<>();

Mockito.when(accountMock.getDomainId()).thenReturn(domainId);

queryManagerImplSpy.addDomainIdToSetIfDomainDoesNotShareTemplates(domainId, accountMock, set);

Assert.assertEquals(0, set.size());
}

@Test
public void addDomainIdToSetIfDomainDoesNotShareTemplatesTestAddsWhenDomainDoesNotShareTemplates() {
long domainId = 1L;
Set<Long> set = new HashSet<>();

Mockito.when(accountMock.getDomainId()).thenReturn(2L);
Mockito.doReturn(false).when(queryManagerImplSpy).checkIfDomainSharesTemplates(domainId);

queryManagerImplSpy.addDomainIdToSetIfDomainDoesNotShareTemplates(domainId, accountMock, set);

Assert.assertTrue(set.contains(domainId));
}
}